Mirror of https://github.com/csunny/DB-GPT.git

refactor: Refactor storage system (#937)
@@ -1,188 +1,238 @@
 -- You can change `dbgpt` to your actual metadata database name in your `.env` file
 -- eg. `LOCAL_DB_NAME=dbgpt`

-CREATE DATABASE IF NOT EXISTS dbgpt;
+CREATE
+DATABASE IF NOT EXISTS dbgpt;
 use dbgpt;

 -- For alembic migration tool
-CREATE TABLE `alembic_version` (
+CREATE TABLE IF NOT EXISTS `alembic_version`
+(
     version_num VARCHAR(32) NOT NULL,
     CONSTRAINT alembic_version_pkc PRIMARY KEY (version_num)
 );

-CREATE TABLE `knowledge_space` (
+CREATE TABLE IF NOT EXISTS `knowledge_space`
+(
     `id` int NOT NULL AUTO_INCREMENT COMMENT 'auto increment id',
     `name` varchar(100) NOT NULL COMMENT 'knowledge space name',
     `vector_type` varchar(50) NOT NULL COMMENT 'vector type',
     `desc` varchar(500) NOT NULL COMMENT 'description',
     `owner` varchar(100) DEFAULT NULL COMMENT 'owner',
     `context` TEXT DEFAULT NULL COMMENT 'context argument',
     `gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
     `gmt_modified` TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time',
     PRIMARY KEY (`id`),
     KEY `idx_name` (`name`) COMMENT 'index:idx_name'
-) ENGINE=InnoDB AUTO_INCREMENT=100001 DEFAULT CHARSET=utf8mb4 COMMENT='knowledge space table';
+) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT='knowledge space table';

-CREATE TABLE `knowledge_document` (
+CREATE TABLE IF NOT EXISTS `knowledge_document`
+(
     `id` int NOT NULL AUTO_INCREMENT COMMENT 'auto increment id',
     `doc_name` varchar(100) NOT NULL COMMENT 'document path name',
     `doc_type` varchar(50) NOT NULL COMMENT 'doc type',
     `space` varchar(50) NOT NULL COMMENT 'knowledge space',
     `chunk_size` int NOT NULL COMMENT 'chunk size',
     `last_sync` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'last sync time',
     `status` varchar(50) NOT NULL COMMENT 'status TODO,RUNNING,FAILED,FINISHED',
     `content` LONGTEXT NOT NULL COMMENT 'knowledge embedding sync result',
     `result` TEXT NULL COMMENT 'knowledge content',
     `vector_ids` LONGTEXT NULL COMMENT 'vector_ids',
     `summary` LONGTEXT NULL COMMENT 'knowledge summary',
     `gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
     `gmt_modified` TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time',
     PRIMARY KEY (`id`),
     KEY `idx_doc_name` (`doc_name`) COMMENT 'index:idx_doc_name'
-) ENGINE=InnoDB AUTO_INCREMENT=100001 DEFAULT CHARSET=utf8mb4 COMMENT='knowledge document table';
+) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT='knowledge document table';

-CREATE TABLE `document_chunk` (
+CREATE TABLE IF NOT EXISTS `document_chunk`
+(
     `id` int NOT NULL AUTO_INCREMENT COMMENT 'auto increment id',
     `doc_name` varchar(100) NOT NULL COMMENT 'document path name',
     `doc_type` varchar(50) NOT NULL COMMENT 'doc type',
     `document_id` int NOT NULL COMMENT 'document parent id',
     `content` longtext NOT NULL COMMENT 'chunk content',
     `meta_info` varchar(200) NOT NULL COMMENT 'metadata info',
     `gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
     `gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time',
     PRIMARY KEY (`id`),
     KEY `idx_document_id` (`document_id`) COMMENT 'index:document_id'
-) ENGINE=InnoDB AUTO_INCREMENT=100001 DEFAULT CHARSET=utf8mb4 COMMENT='knowledge document chunk detail';
+) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT='knowledge document chunk detail';


-CREATE TABLE `connect_config` (
+CREATE TABLE IF NOT EXISTS `connect_config`
+(
     `id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id',
     `db_type` varchar(255) NOT NULL COMMENT 'db type',
     `db_name` varchar(255) NOT NULL COMMENT 'db name',
     `db_path` varchar(255) DEFAULT NULL COMMENT 'file db path',
     `db_host` varchar(255) DEFAULT NULL COMMENT 'db connect host(not file db)',
     `db_port` varchar(255) DEFAULT NULL COMMENT 'db cnnect port(not file db)',
     `db_user` varchar(255) DEFAULT NULL COMMENT 'db user',
     `db_pwd` varchar(255) DEFAULT NULL COMMENT 'db password',
     `comment` text COMMENT 'db comment',
     `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code',
     PRIMARY KEY (`id`),
     UNIQUE KEY `uk_db` (`db_name`),
     KEY `idx_q_db_type` (`db_type`)
-) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT 'Connection confi';
+) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT 'Connection confi';

-CREATE TABLE `chat_history` (
+CREATE TABLE IF NOT EXISTS `chat_history`
+(
     `id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id',
     `conv_uid` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'Conversation record unique id',
     `chat_mode` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'Conversation scene mode',
     `summary` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'Conversation record summary',
     `user_name` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'interlocutor',
     `messages` text COLLATE utf8mb4_unicode_ci COMMENT 'Conversation details',
+    `message_ids` text COLLATE utf8mb4_unicode_ci COMMENT 'Message id list, split by comma',
     `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code',
+    `gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
+    `gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time',
+    UNIQUE KEY `conv_uid` (`conv_uid`),
     PRIMARY KEY (`id`)
-) ENGINE=InnoDB AUTO_INCREMENT=2 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT 'Chat history';
+) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT 'Chat history';
+
+CREATE TABLE IF NOT EXISTS `chat_history_message`
+(
+    `id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id',
+    `conv_uid` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'Conversation record unique id',
+    `index` int NOT NULL COMMENT 'Message index',
+    `round_index` int NOT NULL COMMENT 'Round of conversation',
+    `message_detail` text COLLATE utf8mb4_unicode_ci COMMENT 'Message details, json format',
+    `gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
+    `gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time',
+    UNIQUE KEY `message_uid_index` (`conv_uid`, `index`),
+    PRIMARY KEY (`id`)
+) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT 'Chat history message';

-CREATE TABLE `chat_feed_back` (
+CREATE TABLE IF NOT EXISTS `chat_feed_back`
+(
     `id` bigint(20) NOT NULL AUTO_INCREMENT,
     `conv_uid` varchar(128) DEFAULT NULL COMMENT 'Conversation ID',
     `conv_index` int(4) DEFAULT NULL COMMENT 'Round of conversation',
     `score` int(1) DEFAULT NULL COMMENT 'Score of user',
     `ques_type` varchar(32) DEFAULT NULL COMMENT 'User question category',
     `question` longtext DEFAULT NULL COMMENT 'User question',
     `knowledge_space` varchar(128) DEFAULT NULL COMMENT 'Knowledge space name',
     `messages` longtext DEFAULT NULL COMMENT 'The details of user feedback',
     `user_name` varchar(128) DEFAULT NULL COMMENT 'User name',
     `gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
     `gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time',
     PRIMARY KEY (`id`),
     UNIQUE KEY `uk_conv` (`conv_uid`,`conv_index`),
     KEY `idx_conv` (`conv_uid`,`conv_index`)
-) ENGINE=InnoDB AUTO_INCREMENT=0 DEFAULT CHARSET=utf8mb4 COMMENT='User feedback table';
+) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT='User feedback table';


-CREATE TABLE `my_plugin` (
+CREATE TABLE IF NOT EXISTS `my_plugin`
+(
     `id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id',
     `tenant` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'user tenant',
     `user_code` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'user code',
     `user_name` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'user name',
     `name` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'plugin name',
     `file_name` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'plugin package file name',
     `type` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin type',
     `version` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin version',
     `use_count` int DEFAULT NULL COMMENT 'plugin total use count',
     `succ_count` int DEFAULT NULL COMMENT 'plugin total success count',
     `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code',
     `gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'plugin install time',
     PRIMARY KEY (`id`),
     UNIQUE KEY `name` (`name`)
-) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='User plugin table';
+) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='User plugin table';

-CREATE TABLE `plugin_hub` (
+CREATE TABLE IF NOT EXISTS `plugin_hub`
+(
     `id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id',
     `name` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'plugin name',
     `description` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'plugin description',
     `author` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin author',
     `email` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin author email',
     `type` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin type',
     `version` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin version',
     `storage_channel` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin storage channel',
     `storage_url` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin download url',
     `download_param` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin download param',
     `gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'plugin upload time',
     `installed` int DEFAULT NULL COMMENT 'plugin already installed count',
     PRIMARY KEY (`id`),
     UNIQUE KEY `name` (`name`)
-) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='Plugin Hub table';
+) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='Plugin Hub table';


-CREATE TABLE `prompt_manage` (
+CREATE TABLE IF NOT EXISTS `prompt_manage`
+(
     `id` int(11) NOT NULL AUTO_INCREMENT,
     `chat_scene` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Chat scene',
     `sub_chat_scene` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Sub chat scene',
     `prompt_type` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt type: common or private',
     `prompt_name` varchar(512) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'prompt name',
     `content` longtext COLLATE utf8mb4_unicode_ci COMMENT 'Prompt content',
     `user_name` varchar(128) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'User name',
     `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code',
     `gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
     `gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time',
     PRIMARY KEY (`id`),
     UNIQUE KEY `prompt_name_uiq` (`prompt_name`),
     KEY `gmt_created_idx` (`gmt_created`)
-) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='Prompt management table';
+) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='Prompt management table';


-CREATE DATABASE EXAMPLE_1;
+CREATE
+DATABASE IF NOT EXISTS EXAMPLE_1;
 use EXAMPLE_1;
-CREATE TABLE `users` (
+CREATE TABLE IF NOT EXISTS `users`
+(
     `id` int NOT NULL AUTO_INCREMENT,
     `username` varchar(50) NOT NULL COMMENT '用户名',
     `password` varchar(50) NOT NULL COMMENT '密码',
     `email` varchar(50) NOT NULL COMMENT '邮箱',
     `phone` varchar(20) DEFAULT NULL COMMENT '电话',
     PRIMARY KEY (`id`),
     KEY `idx_username` (`username`) COMMENT '索引:按用户名查询'
-) ENGINE=InnoDB AUTO_INCREMENT=101 DEFAULT CHARSET=utf8mb4 COMMENT='聊天用户表';
+) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT='聊天用户表';

-INSERT INTO users (username, password, email, phone) VALUES ('user_1', 'password_1', 'user_1@example.com', '12345678901');
+INSERT INTO users (username, password, email, phone)
+VALUES ('user_1', 'password_1', 'user_1@example.com', '12345678901');
-INSERT INTO users (username, password, email, phone) VALUES ('user_2', 'password_2', 'user_2@example.com', '12345678902');
+INSERT INTO users (username, password, email, phone)
+VALUES ('user_2', 'password_2', 'user_2@example.com', '12345678902');
-INSERT INTO users (username, password, email, phone) VALUES ('user_3', 'password_3', 'user_3@example.com', '12345678903');
+INSERT INTO users (username, password, email, phone)
+VALUES ('user_3', 'password_3', 'user_3@example.com', '12345678903');
-INSERT INTO users (username, password, email, phone) VALUES ('user_4', 'password_4', 'user_4@example.com', '12345678904');
+INSERT INTO users (username, password, email, phone)
+VALUES ('user_4', 'password_4', 'user_4@example.com', '12345678904');
-INSERT INTO users (username, password, email, phone) VALUES ('user_5', 'password_5', 'user_5@example.com', '12345678905');
+INSERT INTO users (username, password, email, phone)
+VALUES ('user_5', 'password_5', 'user_5@example.com', '12345678905');
-INSERT INTO users (username, password, email, phone) VALUES ('user_6', 'password_6', 'user_6@example.com', '12345678906');
+INSERT INTO users (username, password, email, phone)
+VALUES ('user_6', 'password_6', 'user_6@example.com', '12345678906');
-INSERT INTO users (username, password, email, phone) VALUES ('user_7', 'password_7', 'user_7@example.com', '12345678907');
+INSERT INTO users (username, password, email, phone)
+VALUES ('user_7', 'password_7', 'user_7@example.com', '12345678907');
-INSERT INTO users (username, password, email, phone) VALUES ('user_8', 'password_8', 'user_8@example.com', '12345678908');
+INSERT INTO users (username, password, email, phone)
+VALUES ('user_8', 'password_8', 'user_8@example.com', '12345678908');
-INSERT INTO users (username, password, email, phone) VALUES ('user_9', 'password_9', 'user_9@example.com', '12345678909');
+INSERT INTO users (username, password, email, phone)
+VALUES ('user_9', 'password_9', 'user_9@example.com', '12345678909');
-INSERT INTO users (username, password, email, phone) VALUES ('user_10', 'password_10', 'user_10@example.com', '12345678900');
+INSERT INTO users (username, password, email, phone)
+VALUES ('user_10', 'password_10', 'user_10@example.com', '12345678900');
-INSERT INTO users (username, password, email, phone) VALUES ('user_11', 'password_11', 'user_11@example.com', '12345678901');
+INSERT INTO users (username, password, email, phone)
+VALUES ('user_11', 'password_11', 'user_11@example.com', '12345678901');
-INSERT INTO users (username, password, email, phone) VALUES ('user_12', 'password_12', 'user_12@example.com', '12345678902');
+INSERT INTO users (username, password, email, phone)
+VALUES ('user_12', 'password_12', 'user_12@example.com', '12345678902');
-INSERT INTO users (username, password, email, phone) VALUES ('user_13', 'password_13', 'user_13@example.com', '12345678903');
+INSERT INTO users (username, password, email, phone)
+VALUES ('user_13', 'password_13', 'user_13@example.com', '12345678903');
-INSERT INTO users (username, password, email, phone) VALUES ('user_14', 'password_14', 'user_14@example.com', '12345678904');
+INSERT INTO users (username, password, email, phone)
+VALUES ('user_14', 'password_14', 'user_14@example.com', '12345678904');
-INSERT INTO users (username, password, email, phone) VALUES ('user_15', 'password_15', 'user_15@example.com', '12345678905');
+INSERT INTO users (username, password, email, phone)
+VALUES ('user_15', 'password_15', 'user_15@example.com', '12345678905');
-INSERT INTO users (username, password, email, phone) VALUES ('user_16', 'password_16', 'user_16@example.com', '12345678906');
+INSERT INTO users (username, password, email, phone)
+VALUES ('user_16', 'password_16', 'user_16@example.com', '12345678906');
-INSERT INTO users (username, password, email, phone) VALUES ('user_17', 'password_17', 'user_17@example.com', '12345678907');
+INSERT INTO users (username, password, email, phone)
+VALUES ('user_17', 'password_17', 'user_17@example.com', '12345678907');
-INSERT INTO users (username, password, email, phone) VALUES ('user_18', 'password_18', 'user_18@example.com', '12345678908');
+INSERT INTO users (username, password, email, phone)
+VALUES ('user_18', 'password_18', 'user_18@example.com', '12345678908');
-INSERT INTO users (username, password, email, phone) VALUES ('user_19', 'password_19', 'user_19@example.com', '12345678909');
+INSERT INTO users (username, password, email, phone)
+VALUES ('user_19', 'password_19', 'user_19@example.com', '12345678909');
-INSERT INTO users (username, password, email, phone) VALUES ('user_20', 'password_20', 'user_20@example.com', '12345678900');
+INSERT INTO users (username, password, email, phone)
+VALUES ('user_20', 'password_20', 'user_20@example.com', '12345678900');
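Two changes run through the schema hunk above: every `CREATE DATABASE`/`CREATE TABLE` gains `IF NOT EXISTS`, so the script can be re-run against an existing database without erroring, and the hard-coded `AUTO_INCREMENT` seeds (100001, 2, 0, 101) are all reset to 1. The new `chat_history_message` table also moves per-message records out of the `chat_history.messages` blob, keyed by `(conv_uid, index)`. A minimal sketch of applying such an idempotent script from Python — PyMySQL and the naive `;` split are assumptions, sufficient here only because the file defines no procedures or triggers:

    # Sketch: re-runnable schema application. Assumes PyMySQL and that the
    # script contains no ";" inside string literals, procedures or triggers.
    import pymysql

    def apply_schema(sql_path: str, **connect_kwargs) -> None:
        with open(sql_path, "r", encoding="utf-8") as f:
            statements = [s.strip() for s in f.read().split(";") if s.strip()]
        conn = pymysql.connect(**connect_kwargs)  # e.g. host, user, password
        try:
            with conn.cursor() as cursor:
                for stmt in statements:
                    # Safe to re-run: every CREATE above is IF NOT EXISTS
                    cursor.execute(stmt)
            conn.commit()
        finally:
            conn.close()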
@@ -182,6 +182,7 @@ class Config(metaclass=Singleton):
         self.LOCAL_DB_USER = os.getenv("LOCAL_DB_USER", "root")
         self.LOCAL_DB_PASSWORD = os.getenv("LOCAL_DB_PASSWORD", "aa123456")
         self.LOCAL_DB_POOL_SIZE = int(os.getenv("LOCAL_DB_POOL_SIZE", 10))
+        self.LOCAL_DB_POOL_OVERFLOW = int(os.getenv("LOCAL_DB_POOL_OVERFLOW", 20))

         self.CHAT_HISTORY_STORE_TYPE = os.getenv("CHAT_HISTORY_STORE_TYPE", "db")

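The new `LOCAL_DB_POOL_OVERFLOW` option pairs with `LOCAL_DB_POOL_SIZE`: in SQLAlchemy terms these are `pool_size` (connections kept open in the pool) and `max_overflow` (extra connections allowed under load, discarded afterwards), and the `_initialize_db` hunk further down passes them through as engine arguments. A sketch of the effective engine configuration — the URL is an assumption built from the defaults above:

    # Sketch: with pool_size=10 and max_overflow=20, at most 30 connections
    # can exist at once; only 10 are kept open when idle.
    from sqlalchemy import create_engine

    engine = create_engine(
        "mysql+pymysql://root:aa123456@127.0.0.1:3306/dbgpt",  # assumed URL
        pool_size=10,      # LOCAL_DB_POOL_SIZE
        max_overflow=20,   # LOCAL_DB_POOL_OVERFLOW
        pool_pre_ping=True,
    )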
@@ -2,16 +2,10 @@ from datetime import datetime
 from sqlalchemy import Column, Integer, String, DateTime, func
 from sqlalchemy import UniqueConstraint

-from dbgpt.storage.metadata import BaseDao
-from dbgpt.storage.metadata.meta_data import (
-    Base,
-    engine,
-    session,
-    META_DATA_DATABASE,
-)
+from dbgpt.storage.metadata import BaseDao, Model


-class MyPluginEntity(Base):
+class MyPluginEntity(Model):
     __tablename__ = "my_plugin"
     __table_args__ = {
         "mysql_charset": "utf8mb4",
@@ -39,16 +33,8 @@ class MyPluginEntity(Base):


 class MyPluginDao(BaseDao[MyPluginEntity]):
-    def __init__(self):
-        super().__init__(
-            database=META_DATA_DATABASE,
-            orm_base=Base,
-            db_engine=engine,
-            session=session,
-        )
-
     def add(self, engity: MyPluginEntity):
-        session = self.get_session()
+        session = self.get_raw_session()
         my_plugin = MyPluginEntity(
             tenant=engity.tenant,
             user_code=engity.user_code,
@@ -68,13 +54,13 @@ class MyPluginDao(BaseDao[MyPluginEntity]):
         return id

     def update(self, entity: MyPluginEntity):
-        session = self.get_session()
+        session = self.get_raw_session()
         updated = session.merge(entity)
         session.commit()
         return updated.id

     def get_by_user(self, user: str) -> list[MyPluginEntity]:
-        session = self.get_session()
+        session = self.get_raw_session()
         my_plugins = session.query(MyPluginEntity)
         if user:
             my_plugins = my_plugins.filter(MyPluginEntity.user_code == user)
@@ -83,7 +69,7 @@ class MyPluginDao(BaseDao[MyPluginEntity]):
         return result

     def get_by_user_and_plugin(self, user: str, plugin: str) -> MyPluginEntity:
-        session = self.get_session()
+        session = self.get_raw_session()
         my_plugins = session.query(MyPluginEntity)
         if user:
             my_plugins = my_plugins.filter(MyPluginEntity.user_code == user)
@@ -93,7 +79,7 @@ class MyPluginDao(BaseDao[MyPluginEntity]):
         return result

     def list(self, query: MyPluginEntity, page=1, page_size=20) -> list[MyPluginEntity]:
-        session = self.get_session()
+        session = self.get_raw_session()
         my_plugins = session.query(MyPluginEntity)
         all_count = my_plugins.count()
         if query.id is not None:
@@ -122,7 +108,7 @@ class MyPluginDao(BaseDao[MyPluginEntity]):
         return result, total_pages, all_count

     def count(self, query: MyPluginEntity):
-        session = self.get_session()
+        session = self.get_raw_session()
         my_plugins = session.query(func.count(MyPluginEntity.id))
         if query.id is not None:
             my_plugins = my_plugins.filter(MyPluginEntity.id == query.id)
@@ -143,7 +129,7 @@ class MyPluginDao(BaseDao[MyPluginEntity]):
         return count

     def delete(self, plugin_id: int):
-        session = self.get_session()
+        session = self.get_raw_session()
         if plugin_id is None:
             raise Exception("plugin_id is None")
         query = MyPluginEntity(id=plugin_id)
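The pattern in this file repeats across every DAO the commit touches: entities subclass `Model` instead of a hand-imported `Base`, the `BaseDao.__init__` boilerplate (database name, ORM base, engine, session) disappears because the shared database manager now supplies them, and `get_session()` becomes `get_raw_session()`, making explicit that the caller receives an unmanaged session it must commit and close itself. A generic sketch of the two session styles this rename distinguishes — plain SQLAlchemy, not DB-GPT's actual `BaseDao`:

    # Generic sketch (names assumed): "raw session, caller manages it"
    # versus "context manager that commits, rolls back and closes for you".
    from contextlib import contextmanager
    from sqlalchemy.orm import Session, sessionmaker

    class SketchDao:
        def __init__(self, session_factory: sessionmaker):
            self._session_factory = session_factory

        def get_raw_session(self) -> Session:
            """Caller is responsible for commit() and close()."""
            return self._session_factory()

        @contextmanager
        def session(self):
            """Commit on success, roll back on error, always close."""
            session = self._session_factory()
            try:
                yield session
                session.commit()
            except Exception:
                session.rollback()
                raise
            finally:
                session.close()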
@@ -3,19 +3,13 @@ import pytz
 from sqlalchemy import Column, Integer, String, Index, DateTime, func, DDL
 from sqlalchemy import UniqueConstraint

-from dbgpt.storage.metadata import BaseDao
-from dbgpt.storage.metadata.meta_data import (
-    Base,
-    engine,
-    session,
-    META_DATA_DATABASE,
-)
+from dbgpt.storage.metadata import BaseDao, Model

 # TODO We should consider that the production environment does not have permission to execute the DDL
 char_set_sql = DDL("ALTER TABLE plugin_hub CONVERT TO CHARACTER SET utf8mb4")


-class PluginHubEntity(Base):
+class PluginHubEntity(Model):
     __tablename__ = "plugin_hub"
     __table_args__ = {
         "mysql_charset": "utf8mb4",
@@ -43,16 +37,8 @@ class PluginHubEntity(Base):


 class PluginHubDao(BaseDao[PluginHubEntity]):
-    def __init__(self):
-        super().__init__(
-            database=META_DATA_DATABASE,
-            orm_base=Base,
-            db_engine=engine,
-            session=session,
-        )
-
     def add(self, engity: PluginHubEntity):
-        session = self.get_session()
+        session = self.get_raw_session()
         timezone = pytz.timezone("Asia/Shanghai")
         plugin_hub = PluginHubEntity(
             name=engity.name,
@@ -71,7 +57,7 @@ class PluginHubDao(BaseDao[PluginHubEntity]):
         return id

     def update(self, entity: PluginHubEntity):
-        session = self.get_session()
+        session = self.get_raw_session()
         try:
             updated = session.merge(entity)
             session.commit()
@@ -82,7 +68,7 @@ class PluginHubDao(BaseDao[PluginHubEntity]):
     def list(
         self, query: PluginHubEntity, page=1, page_size=20
     ) -> list[PluginHubEntity]:
-        session = self.get_session()
+        session = self.get_raw_session()
         plugin_hubs = session.query(PluginHubEntity)
         all_count = plugin_hubs.count()

@@ -111,7 +97,7 @@ class PluginHubDao(BaseDao[PluginHubEntity]):
         return result, total_pages, all_count

     def get_by_storage_url(self, storage_url):
-        session = self.get_session()
+        session = self.get_raw_session()
         plugin_hubs = session.query(PluginHubEntity)
         plugin_hubs = plugin_hubs.filter(PluginHubEntity.storage_url == storage_url)
         result = plugin_hubs.all()
@@ -119,7 +105,7 @@ class PluginHubDao(BaseDao[PluginHubEntity]):
         return result

     def get_by_name(self, name: str) -> PluginHubEntity:
-        session = self.get_session()
+        session = self.get_raw_session()
         plugin_hubs = session.query(PluginHubEntity)
         plugin_hubs = plugin_hubs.filter(PluginHubEntity.name == name)
         result = plugin_hubs.first()
@@ -127,7 +113,7 @@ class PluginHubDao(BaseDao[PluginHubEntity]):
         return result

     def count(self, query: PluginHubEntity):
-        session = self.get_session()
+        session = self.get_raw_session()
         plugin_hubs = session.query(func.count(PluginHubEntity.id))
         if query.id is not None:
             plugin_hubs = plugin_hubs.filter(PluginHubEntity.id == query.id)
@@ -146,7 +132,7 @@ class PluginHubDao(BaseDao[PluginHubEntity]):
         return count

     def delete(self, plugin_id: int):
-        session = self.get_session()
+        session = self.get_raw_session()
         if plugin_id is None:
             raise Exception("plugin_id is None")
         plugin_hubs = session.query(PluginHubEntity)
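`char_set_sql` above is a SQLAlchemy `DDL` object, which does nothing on its own until bound to an event; whether DB-GPT registers it elsewhere is not visible in this diff. The usual wiring, as a sketch:

    # Sketch only: how a DDL object is normally attached in SQLAlchemy.
    from sqlalchemy import DDL, event

    char_set_sql = DDL("ALTER TABLE plugin_hub CONVERT TO CHARACTER SET utf8mb4")
    # Run the ALTER right after the table is created; PluginHubEntity is the
    # model defined above, assumed to be imported in real code.
    event.listen(PluginHubEntity.__table__, "after_create", char_set_sql)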
@@ -59,18 +59,12 @@ class AgentHub:
             else:
                 my_plugin_entity.user_code = Default_User

-            with self.hub_dao.get_session() as session:
-                try:
-                    if my_plugin_entity.id is None:
-                        session.add(my_plugin_entity)
-                    else:
-                        session.merge(my_plugin_entity)
-                    session.merge(plugin_entity)
-                    session.commit()
-                    session.close()
-                except Exception as e:
-                    logger.error("install merge roll back!" + str(e))
-                    session.rollback()
+            with self.hub_dao.session() as session:
+                if my_plugin_entity.id is None:
+                    session.add(my_plugin_entity)
+                else:
+                    session.merge(my_plugin_entity)
+                session.merge(plugin_entity)
         except Exception as e:
             logger.error("install pluguin exception!", e)
             raise ValueError(f"Install Plugin {plugin_name} Faild! {str(e)}")
@@ -87,19 +81,15 @@ class AgentHub:
         my_plugin_entity = self.my_plugin_dao.get_by_user_and_plugin(user, plugin_name)
         if plugin_entity is not None:
             plugin_entity.installed = plugin_entity.installed - 1
-        with self.hub_dao.get_session() as session:
-            try:
-                my_plugin_q = session.query(MyPluginEntity).filter(
-                    MyPluginEntity.name == plugin_name
-                )
-                if user:
-                    my_plugin_q.filter(MyPluginEntity.user_code == user)
-                my_plugin_q.delete()
-                if plugin_entity is not None:
-                    session.merge(plugin_entity)
-                session.commit()
-            except:
-                session.rollback()
+        with self.hub_dao.session() as session:
+            my_plugin_q = session.query(MyPluginEntity).filter(
+                MyPluginEntity.name == plugin_name
+            )
+            if user:
+                my_plugin_q.filter(MyPluginEntity.user_code == user)
+            my_plugin_q.delete()
+            if plugin_entity is not None:
+                session.merge(plugin_entity)

         if plugin_entity is not None:
             # delete package file if not use
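With `hub_dao.session()` used as a context manager, the deleted `commit()`, `rollback()` and `close()` calls are not lost: leaving the `with` block cleanly commits, and an exception inside it rolls back, as in the session sketch after the `my_plugin` hunks above. Hypothetical usage reduced to its core:

    # Hypothetical: equivalent of the removed try/commit/rollback/close block.
    with hub_dao.session() as session:   # hub_dao assumed to expose session()
        session.merge(plugin_entity)
        # exiting the block commits; an exception inside rolls back instead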
@@ -1,5 +1,7 @@
+from typing import Optional
 import click
 import os
+import functools
 from dbgpt.app.base import WebServerParameters
 from dbgpt.configs.model_config import LOGDIR
 from dbgpt.util.parameter_utils import EnvArgumentParser
@@ -34,3 +36,241 @@ def stop_webserver(port: int):

 def _stop_all_dbgpt_server():
     _stop_service("webserver", "WebServer")
+
+
+@click.group("migration")
+def migration():
+    """Manage database migration"""
+    pass
+
+
+def add_migration_options(func):
+    @click.option(
+        "--alembic_ini_path",
+        required=False,
+        type=str,
+        default=None,
+        show_default=True,
+        help="Alembic ini path, if not set, use 'pilot/meta_data/alembic.ini'",
+    )
+    @click.option(
+        "--script_location",
+        required=False,
+        type=str,
+        default=None,
+        show_default=True,
+        help="Alembic script location, if not set, use 'pilot/meta_data/alembic'",
+    )
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        return func(*args, **kwargs)
+
+    return wrapper
+
+
+@migration.command()
+@add_migration_options
+@click.option(
+    "-m",
+    "--message",
+    required=False,
+    type=str,
+    default="Init migration",
+    show_default=True,
+    help="The message for create migration repository",
+)
+def init(alembic_ini_path: str, script_location: str, message: str):
+    """Initialize database migration repository"""
+    from dbgpt.util._db_migration_utils import create_migration_script
+
+    alembic_cfg, db_manager = _get_migration_config(alembic_ini_path, script_location)
+    create_migration_script(alembic_cfg, db_manager.engine, message)
+
+
+@migration.command()
+@add_migration_options
+@click.option(
+    "-m",
+    "--message",
+    required=False,
+    type=str,
+    default="New migration",
+    show_default=True,
+    help="The message for migration script",
+)
+def migrate(alembic_ini_path: str, script_location: str, message: str):
+    """Create migration script"""
+    from dbgpt.util._db_migration_utils import create_migration_script
+
+    alembic_cfg, db_manager = _get_migration_config(alembic_ini_path, script_location)
+    create_migration_script(alembic_cfg, db_manager.engine, message)
+
+
+@migration.command()
+@add_migration_options
+def upgrade(alembic_ini_path: str, script_location: str):
+    """Upgrade database to target version"""
+    from dbgpt.util._db_migration_utils import upgrade_database
+
+    alembic_cfg, db_manager = _get_migration_config(alembic_ini_path, script_location)
+    upgrade_database(alembic_cfg, db_manager.engine)
+
+
+@migration.command()
+@add_migration_options
+@click.option(
+    "-y",
+    required=False,
+    type=bool,
+    default=False,
+    is_flag=True,
+    help="Confirm to downgrade database",
+)
+@click.option(
+    "-r",
+    "--revision",
+    default="-1",
+    show_default=True,
+    help="Revision to downgrade to",
+)
+def downgrade(alembic_ini_path: str, script_location: str, y: bool, revision: str):
+    """Downgrade database to target version"""
+    from dbgpt.util._db_migration_utils import downgrade_database
+
+    if not y:
+        click.confirm("Are you sure you want to downgrade the database?", abort=True)
+    alembic_cfg, db_manager = _get_migration_config(alembic_ini_path, script_location)
+    downgrade_database(alembic_cfg, db_manager.engine, revision)
+
+
+@migration.command()
+@add_migration_options
+@click.option(
+    "--drop_all_tables",
+    required=False,
+    type=bool,
+    default=False,
+    is_flag=True,
+    help="Drop all tables",
+)
+@click.option(
+    "-y",
+    required=False,
+    type=bool,
+    default=False,
+    is_flag=True,
+    help="Confirm to clean migration data",
+)
+@click.option(
+    "--confirm_drop_all_tables",
+    required=False,
+    type=bool,
+    default=False,
+    is_flag=True,
+    help="Confirm to drop all tables",
+)
+def clean(
+    alembic_ini_path: str,
+    script_location: str,
+    drop_all_tables: bool,
+    y: bool,
+    confirm_drop_all_tables: bool,
+):
+    """Clean Alembic migration scripts and history"""
+    from dbgpt.util._db_migration_utils import clean_alembic_migration
+
+    if not y:
+        click.confirm(
+            "Are you sure clean alembic migration scripts and history?", abort=True
+        )
+    alembic_cfg, db_manager = _get_migration_config(alembic_ini_path, script_location)
+    clean_alembic_migration(alembic_cfg, db_manager.engine)
+    if drop_all_tables:
+        if not confirm_drop_all_tables:
+            click.confirm("\nAre you sure drop all tables?", abort=True)
+        with db_manager.engine.connect() as connection:
+            for tbl in reversed(db_manager.Model.metadata.sorted_tables):
+                print(f"Drop table {tbl.name}")
+                connection.execute(tbl.delete())
+
+
+@migration.command()
+@add_migration_options
+def list(alembic_ini_path: str, script_location: str):
+    """List all versions in the migration history, marking the current one"""
+    from alembic.script import ScriptDirectory
+    from alembic.runtime.migration import MigrationContext
+
+    alembic_cfg, db_manager = _get_migration_config(alembic_ini_path, script_location)
+
+    # Set up Alembic environment and script directory
+    script = ScriptDirectory.from_config(alembic_cfg)
+
+    # Get current revision
+    def get_current_revision():
+        with db_manager.engine.connect() as connection:
+            context = MigrationContext.configure(connection)
+            return context.get_current_revision()
+
+    current_rev = get_current_revision()
+
+    # List all revisions and mark the current one
+    for revision in script.walk_revisions():
+        current_marker = "(current)" if revision.revision == current_rev else ""
+        print(f"{revision.revision} {current_marker}: {revision.doc}")
+
+
+@migration.command()
+@add_migration_options
+@click.argument("revision", required=True)
+def show(alembic_ini_path: str, script_location: str, revision: str):
+    """Show the migration script for a specific version."""
+    from alembic.script import ScriptDirectory
+
+    alembic_cfg, db_manager = _get_migration_config(alembic_ini_path, script_location)
+
+    script = ScriptDirectory.from_config(alembic_cfg)
+
+    rev = script.get_revision(revision)
+    if rev is None:
+        print(f"Revision {revision} not found.")
+        return
+
+    # Find the migration script file
+    script_files = os.listdir(os.path.join(script.dir, "versions"))
+    script_file = next((f for f in script_files if f.startswith(revision)), None)
+
+    if script_file is None:
+        print(f"Migration script for revision {revision} not found.")
+        return
+    # Print the migration script
+    script_file_path = os.path.join(script.dir, "versions", script_file)
+    print(f"Migration script for revision {revision}: {script_file_path}")
+    try:
+        with open(script_file_path, "r") as file:
+            print(file.read())
+    except FileNotFoundError:
+        print(f"Migration script {script_file_path} not found.")
+
+
+def _get_migration_config(
+    alembic_ini_path: Optional[str] = None, script_location: Optional[str] = None
+):
+    from dbgpt.storage.metadata.db_manager import db as db_manager
+    from dbgpt.util._db_migration_utils import create_alembic_config
+
+    # Must import dbgpt_server for initialize db metadata
+    from dbgpt.app.dbgpt_server import initialize_app as _
+    from dbgpt.app.base import _initialize_db
+
+    # initialize db
+    default_meta_data_path = _initialize_db()
+    alembic_cfg = create_alembic_config(
+        default_meta_data_path,
+        db_manager.engine,
+        db_manager.Model,
+        db_manager.session(),
+        alembic_ini_path,
+        script_location,
+    )
+    return alembic_cfg, db_manager
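The helpers imported from `dbgpt.util._db_migration_utils` (`create_migration_script`, `upgrade_database`, `downgrade_database`, `clean_alembic_migration`) are not shown in this diff; from these call sites a plausible reading is that they wrap Alembic's command API. For orientation, the plain-Alembic equivalents would be roughly:

    # Sketch under the assumption that the helpers wrap alembic.command.
    from alembic import command
    from alembic.config import Config as AlembicConfig

    cfg = AlembicConfig("pilot/meta_data/alembic.ini")  # default per --help text

    command.revision(cfg, message="New migration", autogenerate=True)  # ~ migrate
    command.upgrade(cfg, "head")                                       # ~ upgrade
    command.downgrade(cfg, "-1")                                       # ~ downgrade -r -1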
@@ -8,7 +8,8 @@ from dataclasses import dataclass, field
 from dbgpt._private.config import Config
 from dbgpt.component import SystemApp
 from dbgpt.util.parameter_utils import BaseParameters
-from dbgpt.storage.metadata.meta_data import ddl_init_and_upgrade
+
+from dbgpt.util._db_migration_utils import _ddl_init_and_upgrade

 ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 sys.path.append(ROOT_PATH)
@@ -36,8 +37,8 @@ def server_init(param: "WebServerParameters", system_app: SystemApp):
     # init config
     cfg = Config()
     cfg.SYSTEM_APP = system_app
-    ddl_init_and_upgrade(param.disable_alembic_upgrade)
+    # Initialize db storage first
+    _initialize_db_storage(param)

     # load_native_plugins(cfg)
     signal.signal(signal.SIGINT, signal_handler)
@@ -83,6 +84,46 @@ def _create_model_start_listener(system_app: SystemApp):
     return startup_event


+def _initialize_db_storage(param: "WebServerParameters"):
+    """Initialize the db storage.
+
+    Now just support sqlite and mysql. If db type is sqlite, the db path is `pilot/meta_data/{db_name}.db`.
+    """
+    default_meta_data_path = _initialize_db(
+        try_to_create_db=not param.disable_alembic_upgrade
+    )
+    _ddl_init_and_upgrade(default_meta_data_path, param.disable_alembic_upgrade)
+
+
+def _initialize_db(try_to_create_db: Optional[bool] = False) -> str:
+    """Initialize the database
+
+    Now just support sqlite and mysql. If db type is sqlite, the db path is `pilot/meta_data/{db_name}.db`.
+    """
+    from dbgpt.configs.model_config import PILOT_PATH
+    from dbgpt.storage.metadata.db_manager import initialize_db
+    from urllib.parse import quote_plus as urlquote, quote
+
+    CFG = Config()
+    db_name = CFG.LOCAL_DB_NAME
+    default_meta_data_path = os.path.join(PILOT_PATH, "meta_data")
+    os.makedirs(default_meta_data_path, exist_ok=True)
+    if CFG.LOCAL_DB_TYPE == "mysql":
+        db_url = f"mysql+pymysql://{quote(CFG.LOCAL_DB_USER)}:{urlquote(CFG.LOCAL_DB_PASSWORD)}@{CFG.LOCAL_DB_HOST}:{str(CFG.LOCAL_DB_PORT)}"
+    else:
+        sqlite_db_path = os.path.join(default_meta_data_path, f"{db_name}.db")
+        db_url = f"sqlite:///{sqlite_db_path}"
+    engine_args = {
+        "pool_size": CFG.LOCAL_DB_POOL_SIZE,
+        "max_overflow": CFG.LOCAL_DB_POOL_OVERFLOW,
+        "pool_timeout": 30,
+        "pool_recycle": 3600,
+        "pool_pre_ping": True,
+    }
+    initialize_db(db_url, db_name, engine_args, try_to_create_db=try_to_create_db)
+    return default_meta_data_path
+
+
 @dataclass
 class WebServerParameters(BaseParameters):
     host: Optional[str] = field(
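Note that `_initialize_db` runs the user through `quote` and the password through `quote_plus` (imported as `urlquote`) before splicing them into the MySQL URL, so credentials containing URL metacharacters such as `@`, `:` or `/` cannot break parsing. For example:

    # Why the password is percent-encoded before entering the URL:
    from urllib.parse import quote_plus

    password = "p@ss:word/1"     # hypothetical credential
    print(quote_plus(password))  # -> p%40ss%3Aword%2F1
    print(f"mysql+pymysql://root:{quote_plus(password)}@127.0.0.1:3306")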
@@ -13,7 +13,6 @@ from dbgpt.app.base import WebServerParameters
 if TYPE_CHECKING:
     from langchain.embeddings.base import Embeddings
-

 logger = logging.getLogger(__name__)

 CFG = Config()
@@ -3,19 +3,13 @@ from typing import List

 from sqlalchemy import Column, String, DateTime, Integer, Text, func

-from dbgpt.storage.metadata import BaseDao
-from dbgpt.storage.metadata.meta_data import (
-    Base,
-    engine,
-    session,
-    META_DATA_DATABASE,
-)
+from dbgpt.storage.metadata import BaseDao, Model
 from dbgpt._private.config import Config

 CFG = Config()


-class DocumentChunkEntity(Base):
+class DocumentChunkEntity(Model):
     __tablename__ = "document_chunk"
     __table_args__ = {
         "mysql_charset": "utf8mb4",

@@ -35,16 +29,8 @@ class DocumentChunkEntity(Base):


 class DocumentChunkDao(BaseDao):
-    def __init__(self):
-        super().__init__(
-            database=META_DATA_DATABASE,
-            orm_base=Base,
-            db_engine=engine,
-            session=session,
-        )
-
     def create_documents_chunks(self, documents: List):
-        session = self.get_session()
+        session = self.get_raw_session()
         docs = [
             DocumentChunkEntity(
                 doc_name=document.doc_name,

@@ -64,7 +50,7 @@ class DocumentChunkDao(BaseDao):
     def get_document_chunks(
         self, query: DocumentChunkEntity, page=1, page_size=20, document_ids=None
     ):
-        session = self.get_session()
+        session = self.get_raw_session()
         document_chunks = session.query(DocumentChunkEntity)
         if query.id is not None:
             document_chunks = document_chunks.filter(DocumentChunkEntity.id == query.id)

@@ -102,7 +88,7 @@ class DocumentChunkDao(BaseDao):
         return result

     def get_document_chunks_count(self, query: DocumentChunkEntity):
-        session = self.get_session()
+        session = self.get_raw_session()
         document_chunks = session.query(func.count(DocumentChunkEntity.id))
         if query.id is not None:
             document_chunks = document_chunks.filter(DocumentChunkEntity.id == query.id)

@@ -127,7 +113,7 @@ class DocumentChunkDao(BaseDao):
         return count

     def delete(self, document_id: int):
-        session = self.get_session()
+        session = self.get_raw_session()
         if document_id is None:
             raise Exception("document_id is None")
         query = DocumentChunkEntity(document_id=document_id)
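The same mechanical change repeats in every DAO below: entities inherit from `Model` instead of a hand-wired declarative `Base`, the per-DAO `__init__` plumbing disappears, and `get_session()` becomes `get_raw_session()`. A sketch of the resulting call pattern, assuming `get_raw_session()` returns a plain SQLAlchemy `Session` owned by the caller (the helper name here is hypothetical):

    def get_chunks_for_document(dao: DocumentChunkDao, document_id: int):
        # get_raw_session() hands back an unmanaged SQLAlchemy Session;
        # close it explicitly so pooled connections are not leaked.
        session = dao.get_raw_session()
        try:
            return (
                session.query(DocumentChunkEntity)
                .filter(DocumentChunkEntity.document_id == document_id)
                .all()
            )
        finally:
            session.close()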
@@ -2,19 +2,13 @@ from datetime import datetime

 from sqlalchemy import Column, String, DateTime, Integer, Text, func

-from dbgpt.storage.metadata import BaseDao
-from dbgpt.storage.metadata.meta_data import (
-    Base,
-    engine,
-    session,
-    META_DATA_DATABASE,
-)
+from dbgpt.storage.metadata import BaseDao, Model
 from dbgpt._private.config import Config

 CFG = Config()


-class KnowledgeDocumentEntity(Base):
+class KnowledgeDocumentEntity(Model):
     __tablename__ = "knowledge_document"
     __table_args__ = {
         "mysql_charset": "utf8mb4",

@@ -39,16 +33,8 @@ class KnowledgeDocumentEntity(Base):


 class KnowledgeDocumentDao(BaseDao):
-    def __init__(self):
-        super().__init__(
-            database=META_DATA_DATABASE,
-            orm_base=Base,
-            db_engine=engine,
-            session=session,
-        )
-
     def create_knowledge_document(self, document: KnowledgeDocumentEntity):
-        session = self.get_session()
+        session = self.get_raw_session()
         knowledge_document = KnowledgeDocumentEntity(
             doc_name=document.doc_name,
             doc_type=document.doc_type,

@@ -69,7 +55,7 @@ class KnowledgeDocumentDao(BaseDao):
         return doc_id

     def get_knowledge_documents(self, query, page=1, page_size=20):
-        session = self.get_session()
+        session = self.get_raw_session()
         print(f"current session:{session}")
         knowledge_documents = session.query(KnowledgeDocumentEntity)
         if query.id is not None:

@@ -104,7 +90,7 @@ class KnowledgeDocumentDao(BaseDao):
         return result

     def get_documents(self, query):
-        session = self.get_session()
+        session = self.get_raw_session()
         print(f"current session:{session}")
         knowledge_documents = session.query(KnowledgeDocumentEntity)
         if query.id is not None:

@@ -136,7 +122,7 @@ class KnowledgeDocumentDao(BaseDao):
         return result

     def get_knowledge_documents_count_bulk(self, space_names):
-        session = self.get_session()
+        session = self.get_raw_session()
         """
         Perform a batch query to count the number of documents for each knowledge space.

@@ -161,7 +147,7 @@ class KnowledgeDocumentDao(BaseDao):
         return docs_count

     def get_knowledge_documents_count(self, query):
-        session = self.get_session()
+        session = self.get_raw_session()
         knowledge_documents = session.query(func.count(KnowledgeDocumentEntity.id))
         if query.id is not None:
             knowledge_documents = knowledge_documents.filter(

@@ -188,14 +174,14 @@ class KnowledgeDocumentDao(BaseDao):
         return count

     def update_knowledge_document(self, document: KnowledgeDocumentEntity):
-        session = self.get_session()
+        session = self.get_raw_session()
         updated_space = session.merge(document)
         session.commit()
         return updated_space.id

     #
     def delete(self, query: KnowledgeDocumentEntity):
-        session = self.get_session()
+        session = self.get_raw_session()
         knowledge_documents = session.query(KnowledgeDocumentEntity)
         if query.id is not None:
             knowledge_documents = knowledge_documents.filter(
@@ -2,20 +2,14 @@ from datetime import datetime

 from sqlalchemy import Column, Integer, Text, String, DateTime

-from dbgpt.storage.metadata import BaseDao
-from dbgpt.storage.metadata.meta_data import (
-    Base,
-    engine,
-    session,
-    META_DATA_DATABASE,
-)
+from dbgpt.storage.metadata import BaseDao, Model
 from dbgpt._private.config import Config
 from dbgpt.app.knowledge.request.request import KnowledgeSpaceRequest

 CFG = Config()


-class KnowledgeSpaceEntity(Base):
+class KnowledgeSpaceEntity(Model):
     __tablename__ = "knowledge_space"
     __table_args__ = {
         "mysql_charset": "utf8mb4",

@@ -35,16 +29,8 @@ class KnowledgeSpaceEntity(Base):


 class KnowledgeSpaceDao(BaseDao):
-    def __init__(self):
-        super().__init__(
-            database=META_DATA_DATABASE,
-            orm_base=Base,
-            db_engine=engine,
-            session=session,
-        )
-
     def create_knowledge_space(self, space: KnowledgeSpaceRequest):
-        session = self.get_session()
+        session = self.get_raw_session()
         knowledge_space = KnowledgeSpaceEntity(
             name=space.name,
             vector_type=CFG.VECTOR_STORE_TYPE,

@@ -58,7 +44,7 @@ class KnowledgeSpaceDao(BaseDao):
         session.close()

     def get_knowledge_space(self, query: KnowledgeSpaceEntity):
-        session = self.get_session()
+        session = self.get_raw_session()
         knowledge_spaces = session.query(KnowledgeSpaceEntity)
         if query.id is not None:
             knowledge_spaces = knowledge_spaces.filter(

@@ -97,14 +83,14 @@ class KnowledgeSpaceDao(BaseDao):
         return result

     def update_knowledge_space(self, space: KnowledgeSpaceEntity):
-        session = self.get_session()
+        session = self.get_raw_session()
         session.merge(space)
         session.commit()
         session.close()
         return True

     def delete_knowledge_space(self, space: KnowledgeSpaceEntity):
-        session = self.get_session()
+        session = self.get_raw_session()
         if space:
             session.delete(space)
             session.commit()
@@ -2,17 +2,12 @@ from datetime import datetime

 from sqlalchemy import Column, Integer, Text, String, DateTime

-from dbgpt.storage.metadata import BaseDao
-from dbgpt.storage.metadata.meta_data import (
-    Base,
-    engine,
-    session,
-    META_DATA_DATABASE,
-)
+from dbgpt.storage.metadata import BaseDao, Model
 from dbgpt.app.openapi.api_v1.feedback.feed_back_model import FeedBackBody


-class ChatFeedBackEntity(Base):
+class ChatFeedBackEntity(Model):
     __tablename__ = "chat_feed_back"
     __table_args__ = {
         "mysql_charset": "utf8mb4",

@@ -39,18 +34,10 @@ class ChatFeedBackEntity(Base):


 class ChatFeedBackDao(BaseDao):
-    def __init__(self):
-        super().__init__(
-            database=META_DATA_DATABASE,
-            orm_base=Base,
-            db_engine=engine,
-            session=session,
-        )
-
     def create_or_update_chat_feed_back(self, feed_back: FeedBackBody):
         # Todo: We need to have user information first.
-        session = self.get_session()
+        session = self.get_raw_session()
         chat_feed_back = ChatFeedBackEntity(
             conv_uid=feed_back.conv_uid,
             conv_index=feed_back.conv_index,

@@ -84,7 +71,7 @@ class ChatFeedBackDao(BaseDao):
         session.close()

     def get_chat_feed_back(self, conv_uid: str, conv_index: int):
-        session = self.get_session()
+        session = self.get_raw_session()
         result = (
             session.query(ChatFeedBackEntity)
             .filter(ChatFeedBackEntity.conv_uid == conv_uid)
@@ -2,13 +2,8 @@ from datetime import datetime

 from sqlalchemy import Column, Integer, Text, String, DateTime

-from dbgpt.storage.metadata import BaseDao
-from dbgpt.storage.metadata.meta_data import (
-    Base,
-    engine,
-    session,
-    META_DATA_DATABASE,
-)
+from dbgpt.storage.metadata import BaseDao, Model
 from dbgpt._private.config import Config

 from dbgpt.app.prompt.request.request import PromptManageRequest

@@ -16,7 +11,7 @@ from dbgpt.app.prompt.request.request import PromptManageRequest
 CFG = Config()


-class PromptManageEntity(Base):
+class PromptManageEntity(Model):
     __tablename__ = "prompt_manage"
     __table_args__ = {
         "mysql_charset": "utf8mb4",

@@ -38,16 +33,8 @@ class PromptManageEntity(Base):


 class PromptManageDao(BaseDao):
-    def __init__(self):
-        super().__init__(
-            database=META_DATA_DATABASE,
-            orm_base=Base,
-            db_engine=engine,
-            session=session,
-        )
-
     def create_prompt(self, prompt: PromptManageRequest):
-        session = self.get_session()
+        session = self.get_raw_session()
         prompt_manage = PromptManageEntity(
             chat_scene=prompt.chat_scene,
             sub_chat_scene=prompt.sub_chat_scene,

@@ -64,7 +51,7 @@ class PromptManageDao(BaseDao):
         session.close()

     def get_prompts(self, query: PromptManageEntity):
-        session = self.get_session()
+        session = self.get_raw_session()
         prompts = session.query(PromptManageEntity)
         if query.chat_scene is not None:
             prompts = prompts.filter(PromptManageEntity.chat_scene == query.chat_scene)

@@ -93,13 +80,13 @@ class PromptManageDao(BaseDao):
         return result

     def update_prompt(self, prompt: PromptManageEntity):
-        session = self.get_session()
+        session = self.get_raw_session()
         session.merge(prompt)
         session.commit()
         session.close()

     def delete_prompt(self, prompt: PromptManageEntity):
-        session = self.get_session()
+        session = self.get_raw_session()
         if prompt:
             session.delete(prompt)
             session.commit()
@@ -146,7 +146,9 @@ class BaseChat(ABC):
         input_values = await self.generate_input_values()
         ### Chat sequence advance
         self.current_message.chat_order = len(self.history_message) + 1
-        self.current_message.add_user_message(self.current_user_input)
+        self.current_message.add_user_message(
+            self.current_user_input, check_duplicate_type=True
+        )
         self.current_message.start_date = datetime.datetime.now().strftime(
             "%Y-%m-%d %H:%M:%S"
         )

@@ -221,7 +223,7 @@ class BaseChat(ABC):
                 view_msg = self.stream_plugin_call(msg)
                 view_msg = view_msg.replace("\n", "\\n")
                 yield view_msg
-            self.current_message.add_ai_message(msg)
+            self.current_message.add_ai_message(msg, update_if_exist=True)
             view_msg = self.stream_call_reinforce_fn(view_msg)
             self.current_message.add_view_message(view_msg)
             span.end()

@@ -257,7 +259,7 @@ class BaseChat(ABC):
                 )
             )
             ### model result deal
-            self.current_message.add_ai_message(ai_response_text)
+            self.current_message.add_ai_message(ai_response_text, update_if_exist=True)
             prompt_define_response = (
                 self.prompt_template.output_parser.parse_prompt_response(
                     ai_response_text

@@ -320,7 +322,7 @@ class BaseChat(ABC):
                 )
             )
             ### model result deal
-            self.current_message.add_ai_message(ai_response_text)
+            self.current_message.add_ai_message(ai_response_text, update_if_exist=True)
             prompt_define_response = None
             prompt_define_response = (
                 self.prompt_template.output_parser.parse_prompt_response(

@@ -596,7 +598,7 @@ def _load_system_message(
     prompt_template: PromptTemplate,
    str_message: bool = True,
 ):
-    system_convs = current_message.get_system_conv()
+    system_convs = current_message.get_system_messages()
     system_text = ""
     system_messages = []
     for system_conv in system_convs:

@@ -614,7 +616,7 @@ def _load_user_message(
     prompt_template: PromptTemplate,
     str_message: bool = True,
 ):
-    user_conv = current_message.get_user_conv()
+    user_conv = current_message.get_latest_user_message()
     user_messages = []
     if user_conv:
         user_text = user_conv.type + ":" + user_conv.content + prompt_template.sep
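The `update_if_exist=True` flag matters for streaming: the partial answer can be written repeatedly while staying in a single message slot. A small sketch of the semantics, using the `OnceConversation` API introduced later in this change (`"chat_normal"` is a placeholder chat mode):

    conv = OnceConversation("chat_normal")
    conv.start_new_round()
    conv.add_user_message("hello")
    # First call appends a new AIMessage; the flag has no effect yet.
    conv.add_ai_message("partial ans", update_if_exist=True)
    # Later calls find an existing AIMessage and update it in place,
    # so a streamed answer occupies one message slot instead of many.
    conv.add_ai_message("partial answer, now complete", update_if_exist=True)
    conv.end_current_round()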
@@ -70,7 +70,9 @@ class ChatHistoryManager:

     def _new_chat(self, input_values: Dict) -> List[ModelMessage]:
         self.current_message.chat_order = len(self.history_message) + 1
-        self.current_message.add_user_message(self._chat_ctx.current_user_input)
+        self.current_message.add_user_message(
+            self._chat_ctx.current_user_input, check_duplicate_type=True
+        )
         self.current_message.start_date = datetime.datetime.now().strftime(
             "%Y-%m-%d %H:%M:%S"
         )
@@ -51,6 +51,12 @@ def install():
     pass


+@click.group()
+def db():
+    """Manage your metadata database and your datasources."""
+    pass
+
+
 stop_all_func_list = []


@@ -64,6 +70,7 @@ def stop_all():
 cli.add_command(start)
 cli.add_command(stop)
 cli.add_command(install)
+cli.add_command(db)
 add_command_alias(stop_all, name="all", parent_group=stop)

 try:

@@ -96,10 +103,13 @@ try:
         start_webserver,
         stop_webserver,
         _stop_all_dbgpt_server,
+        migration,
     )

     add_command_alias(start_webserver, name="webserver", parent_group=start)
     add_command_alias(stop_webserver, name="webserver", parent_group=stop)
+    # Add migration command
+    add_command_alias(migration, name="migration", parent_group=db)
     stop_all_func_list.append(_stop_all_dbgpt_server)

 except ImportError as e:
@@ -9,6 +9,10 @@ from dbgpt.core.interface.message import (
     ModelMessage,
     ModelMessageRoleType,
     OnceConversation,
+    StorageConversation,
+    MessageStorageItem,
+    ConversationIdentifier,
+    MessageIdentifier,
 )
 from dbgpt.core.interface.prompt import PromptTemplate, PromptTemplateOperator
 from dbgpt.core.interface.output_parser import BaseOutputParser, SQLOutputParser

@@ -20,6 +24,16 @@ from dbgpt.core.interface.cache import (
     CachePolicy,
     CacheConfig,
 )
+from dbgpt.core.interface.storage import (
+    ResourceIdentifier,
+    StorageItem,
+    StorageItemAdapter,
+    StorageInterface,
+    InMemoryStorage,
+    DefaultStorageItemAdapter,
+    QuerySpec,
+    StorageError,
+)

 __ALL__ = [
     "ModelInferenceMetrics",

@@ -30,6 +44,10 @@ __ALL__ = [
     "ModelMessage",
     "ModelMessageRoleType",
     "OnceConversation",
+    "StorageConversation",
+    "MessageStorageItem",
+    "ConversationIdentifier",
+    "MessageIdentifier",
     "PromptTemplate",
     "PromptTemplateOperator",
     "BaseOutputParser",

@@ -41,4 +59,12 @@ __ALL__ = [
     "CacheClient",
     "CachePolicy",
     "CacheConfig",
+    "ResourceIdentifier",
+    "StorageItem",
+    "StorageItemAdapter",
+    "StorageInterface",
+    "InMemoryStorage",
+    "DefaultStorageItemAdapter",
+    "QuerySpec",
+    "StorageError",
 ]
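With these re-exports in place, downstream code can pull the storage primitives and conversation types from one spot. A small sketch, assuming this file is the package `__init__` (the path is inferred from the `__ALL__` list, not stated in the diff):

    from dbgpt.core import InMemoryStorage, StorageConversation

    # A simple in-process backend; database-backed implementations of
    # StorageInterface plug into the same import surface.
    storage = InMemoryStorage()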
@@ -1,7 +1,7 @@
 import pytest
 import threading
 import asyncio
-from ..dag import DAG, DAGContext
+from ..base import DAG, DAGVar


 def test_dag_context_sync():

@@ -9,18 +9,18 @@ def test_dag_context_sync():
     dag2 = DAG("dag2")

     with dag1:
-        assert DAGContext.get_current_dag() == dag1
+        assert DAGVar.get_current_dag() == dag1
         with dag2:
-            assert DAGContext.get_current_dag() == dag2
+            assert DAGVar.get_current_dag() == dag2
-        assert DAGContext.get_current_dag() == dag1
+        assert DAGVar.get_current_dag() == dag1
-    assert DAGContext.get_current_dag() is None
+    assert DAGVar.get_current_dag() is None


 def test_dag_context_threading():
     def thread_function(dag):
-        DAGContext.enter_dag(dag)
+        DAGVar.enter_dag(dag)
-        assert DAGContext.get_current_dag() == dag
+        assert DAGVar.get_current_dag() == dag
-        DAGContext.exit_dag()
+        DAGVar.exit_dag()

     dag1 = DAG("dag1")
     dag2 = DAG("dag2")

@@ -33,19 +33,19 @@ def test_dag_context_threading():
     thread1.join()
     thread2.join()

-    assert DAGContext.get_current_dag() is None
+    assert DAGVar.get_current_dag() is None


 @pytest.mark.asyncio
 async def test_dag_context_async():
     async def async_function(dag):
-        DAGContext.enter_dag(dag)
+        DAGVar.enter_dag(dag)
-        assert DAGContext.get_current_dag() == dag
+        assert DAGVar.get_current_dag() == dag
-        DAGContext.exit_dag()
+        DAGVar.exit_dag()

     dag1 = DAG("dag1")
     dag2 = DAG("dag2")

     await asyncio.gather(async_function(dag1), async_function(dag2))

-    assert DAGContext.get_current_dag() is None
+    assert DAGVar.get_current_dag() is None
@@ -1,16 +1,26 @@
 from __future__ import annotations

 from abc import ABC, abstractmethod
-from typing import Dict, List, Tuple, Union
+from typing import Dict, List, Tuple, Union, Optional
 from datetime import datetime

 from dbgpt._private.pydantic import BaseModel, Field

+from dbgpt.core.interface.storage import (
+    ResourceIdentifier,
+    StorageItem,
+    StorageInterface,
+    InMemoryStorage,
+)


 class BaseMessage(BaseModel, ABC):
     """Message object."""

     content: str
+    index: int = 0
+    round_index: int = 0
+    """The round index of the message in the conversation"""
     additional_kwargs: dict = Field(default_factory=dict)

     @property

@@ -18,6 +28,24 @@ class BaseMessage(BaseModel, ABC):
     def type(self) -> str:
         """Type of the message, used for serialization."""

+    @property
+    def pass_to_model(self) -> bool:
+        """Whether the message will be passed to the model"""
+        return True
+
+    def to_dict(self) -> Dict:
+        """Convert to dict
+
+        Returns:
+            Dict: The dict object
+        """
+        return {
+            "type": self.type,
+            "data": self.dict(),
+            "index": self.index,
+            "round_index": self.round_index,
+        }
+

 class HumanMessage(BaseMessage):
     """Type of message that is spoken by the human."""

@@ -51,6 +79,14 @@ class ViewMessage(BaseMessage):
         """Type of the message, used for serialization."""
         return "view"

+    @property
+    def pass_to_model(self) -> bool:
+        """Whether the message will be passed to the model
+
+        The view message will not be passed to the model
+        """
+        return False
+

 class SystemMessage(BaseMessage):
     """Type of message that is a system message."""

@@ -141,15 +177,15 @@ class ModelMessage(BaseModel):
         return ModelMessage(role=ModelMessageRoleType.HUMAN, content=content)


-def _message_to_dict(message: BaseMessage) -> dict:
-    return {"type": message.type, "data": message.dict()}
+def _message_to_dict(message: BaseMessage) -> Dict:
+    return message.to_dict()


-def _messages_to_dict(messages: List[BaseMessage]) -> List[dict]:
+def _messages_to_dict(messages: List[BaseMessage]) -> List[Dict]:
     return [_message_to_dict(m) for m in messages]


-def _message_from_dict(message: dict) -> BaseMessage:
+def _message_from_dict(message: Dict) -> BaseMessage:
     _type = message["type"]
     if _type == "human":
         return HumanMessage(**message["data"])

@@ -163,7 +199,7 @@ def _message_from_dict(message: dict) -> BaseMessage:
     raise ValueError(f"Got unexpected type: {_type}")


-def _messages_from_dict(messages: List[dict]) -> List[BaseMessage]:
+def _messages_from_dict(messages: List[Dict]) -> List[BaseMessage]:
     return [_message_from_dict(m) for m in messages]
@@ -193,50 +229,119 @@ def _parse_model_messages(
             history_messages.append([])
     if messages[-1].role != "human":
         raise ValueError("Hi! What do you want to talk about?")
-    # Keep message pair of [user message, assistant message]
+    # Keep messages as pairs of [user message, assistant message]
     history_messages = list(filter(lambda x: len(x) == 2, history_messages))
     user_prompt = messages[-1].content
     return user_prompt, system_messages, history_messages


 class OnceConversation:
-    """
-    All the information of a conversation, the current single service in memory, can expand cache and database support distributed services
-    """
+    """All the information of a conversation, held in memory by the current single
+    service; it can be extended with cache and database backends to support
+    distributed services.
+    """

-    def __init__(self, chat_mode, user_name: str = None, sys_code: str = None):
+    def __init__(
+        self,
+        chat_mode: str,
+        user_name: str = None,
+        sys_code: str = None,
+        summary: str = None,
+        **kwargs,
+    ):
         self.chat_mode: str = chat_mode
-        self.messages: List[BaseMessage] = []
-        self.start_date: str = ""
-        self.chat_order: int = 0
-        self.model_name: str = ""
-        self.param_type: str = ""
-        self.param_value: str = ""
-        self.cost: int = 0
-        self.tokens: int = 0
         self.user_name: str = user_name
         self.sys_code: str = sys_code
+        self.summary: str = summary
+
+        self.messages: List[BaseMessage] = kwargs.get("messages", [])
+        self.start_date: str = kwargs.get("start_date", "")
+        # After each complete round of dialogue, the current value will be increased by 1
+        self.chat_order: int = int(kwargs.get("chat_order", 0))
+        self.model_name: str = kwargs.get("model_name", "")
+        self.param_type: str = kwargs.get("param_type", "")
+        self.param_value: str = kwargs.get("param_value", "")
+        self.cost: int = int(kwargs.get("cost", 0))
+        self.tokens: int = int(kwargs.get("tokens", 0))
+        self._message_index: int = int(kwargs.get("message_index", 0))

-    def add_user_message(self, message: str) -> None:
-        """Add a user message to the store"""
-        has_message = any(
-            isinstance(instance, HumanMessage) for instance in self.messages
-        )
-        if has_message:
-            raise ValueError("Already Have Human message")
-        self.messages.append(HumanMessage(content=message))
-
-    def add_ai_message(self, message: str) -> None:
-        """Add an AI message to the store"""
+    def _append_message(self, message: BaseMessage) -> None:
+        index = self._message_index
+        self._message_index += 1
+        message.index = index
+        message.round_index = self.chat_order
+        self.messages.append(message)
+
+    def start_new_round(self) -> None:
+        """Start a new round of conversation
+
+        Example:
+            >>> conversation = OnceConversation()
+            >>> # The chat order will be 0, then we start a new round of conversation
+            >>> assert conversation.chat_order == 0
+            >>> conversation.start_new_round()
+            >>> # Now the chat order will be 1
+            >>> assert conversation.chat_order == 1
+            >>> conversation.add_user_message("hello")
+            >>> conversation.add_ai_message("hi")
+            >>> conversation.end_current_round()
+            >>> # Now the chat order will be 1, then we start a new round of conversation
+            >>> conversation.start_new_round()
+            >>> # Now the chat order will be 2
+            >>> assert conversation.chat_order == 2
+            >>> conversation.add_user_message("hello")
+            >>> conversation.add_ai_message("hi")
+            >>> conversation.end_current_round()
+            >>> assert conversation.chat_order == 2
+        """
+        self.chat_order += 1
+
+    def end_current_round(self) -> None:
+        """End the current round of conversation
+
+        We do nothing here; the method exists only to complete the interface.
+        """
+        pass
+
+    def add_user_message(
+        self, message: str, check_duplicate_type: Optional[bool] = False
+    ) -> None:
+        """Add a user message to the conversation
+
+        Args:
+            message (str): The message content
+            check_duplicate_type (bool): Whether to check the duplicate message type
+
+        Raises:
+            ValueError: If the message is duplicate and check_duplicate_type is True
+        """
+        if check_duplicate_type:
+            has_message = any(
+                isinstance(instance, HumanMessage) for instance in self.messages
+            )
+            if has_message:
+                raise ValueError("Already Have Human message")
+        self._append_message(HumanMessage(content=message))
+
+    def add_ai_message(
+        self, message: str, update_if_exist: Optional[bool] = False
+    ) -> None:
+        """Add an AI message to the conversation
+
+        Args:
+            message (str): The message content
+            update_if_exist (bool): Whether to update the message if the message type is duplicate
+        """
+        if not update_if_exist:
+            self._append_message(AIMessage(content=message))
+            return
         has_message = any(isinstance(instance, AIMessage) for instance in self.messages)
         if has_message:
-            self.__update_ai_message(message)
+            self._update_ai_message(message)
         else:
-            self.messages.append(AIMessage(content=message))
+            self._append_message(AIMessage(content=message))
-        """ """

-    def __update_ai_message(self, new_message: str) -> None:
+    def _update_ai_message(self, new_message: str) -> None:
         """
         stream out message update
         Args:
|
|||||||
|
|
||||||
def add_view_message(self, message: str) -> None:
|
def add_view_message(self, message: str) -> None:
|
||||||
"""Add an AI message to the store"""
|
"""Add an AI message to the store"""
|
||||||
|
self._append_message(ViewMessage(content=message))
|
||||||
self.messages.append(ViewMessage(content=message))
|
|
||||||
""" """
|
|
||||||
|
|
||||||
def add_system_message(self, message: str) -> None:
|
def add_system_message(self, message: str) -> None:
|
||||||
"""Add an AI message to the store"""
|
"""Add a system message to the store"""
|
||||||
self.messages.append(SystemMessage(content=message))
|
self._append_message(SystemMessage(content=message))
|
||||||
|
|
||||||
def set_start_time(self, datatime: datetime):
|
def set_start_time(self, datatime: datetime):
|
||||||
dt_str = datatime.strftime("%Y-%m-%d %H:%M:%S")
|
dt_str = datatime.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
@@ -267,23 +370,369 @@ class OnceConversation:
|
|||||||
def clear(self) -> None:
|
def clear(self) -> None:
|
||||||
"""Remove all messages from the store"""
|
"""Remove all messages from the store"""
|
||||||
self.messages.clear()
|
self.messages.clear()
|
||||||
self.session_id = None
|
|
||||||
|
|
||||||
def get_user_conv(self):
|
def get_latest_user_message(self) -> Optional[HumanMessage]:
|
||||||
for message in self.messages:
|
"""Get the latest user message"""
|
||||||
|
for message in self.messages[::-1]:
|
||||||
if isinstance(message, HumanMessage):
|
if isinstance(message, HumanMessage):
|
||||||
return message
|
return message
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_system_conv(self):
|
def get_system_messages(self) -> List[SystemMessage]:
|
||||||
system_convs = []
|
"""Get the latest user message"""
|
||||||
|
return list(filter(lambda x: isinstance(x, SystemMessage), self.messages))
|
||||||
|
|
||||||
|
def _to_dict(self) -> Dict:
|
||||||
|
return _conversation_to_dict(self)
|
||||||
|
|
||||||
|
def from_conversation(self, conversation: OnceConversation) -> None:
|
||||||
|
"""Load the conversation from the storage"""
|
||||||
|
self.chat_mode = conversation.chat_mode
|
||||||
|
self.messages = conversation.messages
|
||||||
|
self.start_date = conversation.start_date
|
||||||
|
self.chat_order = conversation.chat_order
|
||||||
|
self.model_name = conversation.model_name
|
||||||
|
self.param_type = conversation.param_type
|
||||||
|
self.param_value = conversation.param_value
|
||||||
|
self.cost = conversation.cost
|
||||||
|
self.tokens = conversation.tokens
|
||||||
|
self.user_name = conversation.user_name
|
||||||
|
self.sys_code = conversation.sys_code
|
||||||
|
|
||||||
|
def get_messages_by_round(self, round_index: int) -> List[BaseMessage]:
|
||||||
|
"""Get the messages by round index
|
||||||
|
|
||||||
|
Args:
|
||||||
|
round_index (int): The round index
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[BaseMessage]: The messages
|
||||||
|
"""
|
||||||
|
return list(filter(lambda x: x.round_index == round_index, self.messages))
|
||||||
|
|
||||||
|
+    def get_latest_round(self) -> List[BaseMessage]:
+        """Get the latest round messages
+
+        Returns:
+            List[BaseMessage]: The messages
+        """
+        return self.get_messages_by_round(self.chat_order)
+
+    def get_messages_with_round(self, round_count: int) -> List[BaseMessage]:
+        """Get the messages with round count
+
+        If the round count is 1, the history messages will not be included.
+
+        Example:
+            .. code-block:: python
+
+                conversation = OnceConversation()
+                conversation.start_new_round()
+                conversation.add_user_message("hello, this is the first round")
+                conversation.add_ai_message("hi")
+                conversation.end_current_round()
+                conversation.start_new_round()
+                conversation.add_user_message("hello, this is the second round")
+                conversation.add_ai_message("hi")
+                conversation.end_current_round()
+                conversation.start_new_round()
+                conversation.add_user_message("hello, this is the third round")
+                conversation.add_ai_message("hi")
+                conversation.end_current_round()
+
+                assert len(conversation.get_messages_with_round(1)) == 2
+                assert conversation.get_messages_with_round(1)[0].content == "hello, this is the third round"
+                assert conversation.get_messages_with_round(1)[1].content == "hi"
+
+                assert len(conversation.get_messages_with_round(2)) == 4
+                assert conversation.get_messages_with_round(2)[0].content == "hello, this is the second round"
+                assert conversation.get_messages_with_round(2)[1].content == "hi"
+
+        Args:
+            round_count (int): The round count
+
+        Returns:
+            List[BaseMessage]: The messages
+        """
+        latest_round_index = self.chat_order
+        start_round_index = max(1, latest_round_index - round_count + 1)
+        messages = []
+        for round_index in range(start_round_index, latest_round_index + 1):
+            messages.extend(self.get_messages_by_round(round_index))
+        return messages
+    def get_model_messages(self) -> List[ModelMessage]:
+        """Get the model messages
+
+        Model messages just include human, ai and system messages.
+        Model messages may include the history messages. The order of the messages is
+        the same as the order of the messages in the conversation, and the last
+        message is the latest message.
+
+        If you want to handle the messages with your own logic, you can override this method.
+
+        Examples:
+            If you do not need the history messages, you can override this method like this:
+            .. code-block:: python
+
+                def get_model_messages(self) -> List[ModelMessage]:
+                    messages = []
+                    for message in self.get_latest_round():
+                        if message.pass_to_model:
+                            messages.append(
+                                ModelMessage(role=message.type, content=message.content)
+                            )
+                    return messages
+
+            If you want to include one round of history messages, you can override this method like this:
+            .. code-block:: python
+
+                def get_model_messages(self) -> List[ModelMessage]:
+                    messages = []
+                    latest_round_index = self.chat_order
+                    round_count = 1
+                    start_round_index = max(1, latest_round_index - round_count + 1)
+                    for round_index in range(start_round_index, latest_round_index + 1):
+                        for message in self.get_messages_by_round(round_index):
+                            if message.pass_to_model:
+                                messages.append(
+                                    ModelMessage(role=message.type, content=message.content)
+                                )
+                    return messages
+
+        Returns:
+            List[ModelMessage]: The model messages
+        """
+        messages = []
+        for message in self.messages:
+            if message.pass_to_model:
+                messages.append(
+                    ModelMessage(role=message.type, content=message.content)
+                )
+        return messages
+
+class ConversationIdentifier(ResourceIdentifier):
+    """Conversation identifier"""
+
+    def __init__(self, conv_uid: str, identifier_type: str = "conversation"):
+        self.conv_uid = conv_uid
+        self.identifier_type = identifier_type
+
+    @property
+    def str_identifier(self) -> str:
+        return f"{self.identifier_type}:{self.conv_uid}"
+
+    def to_dict(self) -> Dict:
+        return {"conv_uid": self.conv_uid, "identifier_type": self.identifier_type}
+
+
+class MessageIdentifier(ResourceIdentifier):
+    """Message identifier"""
+
+    identifier_split = "___"
+
+    def __init__(self, conv_uid: str, index: int, identifier_type: str = "message"):
+        self.conv_uid = conv_uid
+        self.index = index
+        self.identifier_type = identifier_type
+
+    @property
+    def str_identifier(self) -> str:
+        return f"{self.identifier_type}{self.identifier_split}{self.conv_uid}{self.identifier_split}{self.index}"
+
+    @staticmethod
+    def from_str_identifier(str_identifier: str) -> MessageIdentifier:
+        """Convert from str identifier
+
+        Args:
+            str_identifier (str): The str identifier
+
+        Returns:
+            MessageIdentifier: The message identifier
+        """
+        parts = str_identifier.split(MessageIdentifier.identifier_split)
+        if len(parts) != 3:
+            raise ValueError(f"Invalid str identifier: {str_identifier}")
+        return MessageIdentifier(parts[1], int(parts[2]))
+
+    def to_dict(self) -> Dict:
+        return {
+            "conv_uid": self.conv_uid,
+            "index": self.index,
+            "identifier_type": self.identifier_type,
+        }
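A quick round-trip through the string form, as implied by `str_identifier` and `from_str_identifier` above (illustrative values):

    ident = MessageIdentifier("conv-123", 7)
    s = ident.str_identifier  # "message___conv-123___7"
    restored = MessageIdentifier.from_str_identifier(s)
    assert restored.conv_uid == "conv-123" and restored.index == 7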
+
+class MessageStorageItem(StorageItem):
+    @property
+    def identifier(self) -> MessageIdentifier:
+        return self._id
+
+    def __init__(self, conv_uid: str, index: int, message_detail: Dict):
+        self.conv_uid = conv_uid
+        self.index = index
+        self.message_detail = message_detail
+        self._id = MessageIdentifier(conv_uid, index)
+
+    def to_dict(self) -> Dict:
+        return {
+            "conv_uid": self.conv_uid,
+            "index": self.index,
+            "message_detail": self.message_detail,
+        }
+
+    def to_message(self) -> BaseMessage:
+        """Convert to message object
+
+        Returns:
+            BaseMessage: The message object
+
+        Raises:
+            ValueError: If the message type is not supported
+        """
+        return _message_from_dict(self.message_detail)
+
+    def merge(self, other: "StorageItem") -> None:
+        """Merge the other message to self
+
+        Args:
+            other (StorageItem): The other message
+        """
+        if not isinstance(other, MessageStorageItem):
+            raise ValueError(f"Can not merge {other} to {self}")
+        self.message_detail = other.message_detail
+
+class StorageConversation(OnceConversation, StorageItem):
+    """All the information of a conversation, held in memory by the current single
+    service; it can be extended with cache and database backends to support
+    distributed services.
+    """
+
+    @property
+    def identifier(self) -> ConversationIdentifier:
+        return self._id
+
+    def to_dict(self) -> Dict:
+        dict_data = self._to_dict()
+        messages: List[Dict] = dict_data.pop("messages")
+        message_ids = []
+        index = 0
+        for message in messages:
+            if "index" in message:
+                message_idx = message["index"]
+            else:
+                message_idx = index
+                index += 1
+            message_ids.append(
+                MessageIdentifier(self.conv_uid, message_idx).str_identifier
+            )
+        # Replace message with message ids
+        dict_data["conv_uid"] = self.conv_uid
+        dict_data["message_ids"] = message_ids
+        dict_data["save_message_independent"] = self.save_message_independent
+        return dict_data
+
+    def merge(self, other: "StorageItem") -> None:
+        """Merge the other conversation to self
+
+        Args:
+            other (StorageItem): The other conversation
+        """
+        if not isinstance(other, StorageConversation):
+            raise ValueError(f"Can not merge {other} to {self}")
+        self.from_conversation(other)
+
+    def __init__(
+        self,
+        conv_uid: str,
+        chat_mode: str = None,
+        user_name: str = None,
+        sys_code: str = None,
+        message_ids: List[str] = None,
+        summary: str = None,
+        save_message_independent: Optional[bool] = True,
+        conv_storage: StorageInterface = None,
+        message_storage: StorageInterface = None,
+        **kwargs,
+    ):
+        super().__init__(chat_mode, user_name, sys_code, summary, **kwargs)
+        self.conv_uid = conv_uid
+        self._message_ids = message_ids
+        self.save_message_independent = save_message_independent
+        self._id = ConversationIdentifier(conv_uid)
+        if conv_storage is None:
+            conv_storage = InMemoryStorage()
+        if message_storage is None:
+            message_storage = InMemoryStorage()
+        self.conv_storage = conv_storage
+        self.message_storage = message_storage
+        # Load from storage
+        self.load_from_storage(self.conv_storage, self.message_storage)
+
+    @property
+    def message_ids(self) -> List[str]:
+        """Get the message ids
+
+        Returns:
+            List[str]: The message ids
+        """
+        return self._message_ids if self._message_ids else []
+
+    def end_current_round(self) -> None:
+        """End the current round of conversation
+
+        Save the conversation to the storage after a round of conversation
+        """
+        self.save_to_storage()
+
+    def _get_message_items(self) -> List[MessageStorageItem]:
+        return [
+            MessageStorageItem(self.conv_uid, message.index, message.to_dict())
+            for message in self.messages
+        ]
+
+    def save_to_storage(self) -> None:
+        """Save the conversation to the storage"""
+        # Save messages first
+        message_list = self._get_message_items()
+        self._message_ids = [
+            message.identifier.str_identifier for message in message_list
+        ]
+        self.message_storage.save_list(message_list)
+        # Save conversation
+        self.conv_storage.save_or_update(self)
+
+    def load_from_storage(
+        self, conv_storage: StorageInterface, message_storage: StorageInterface
+    ) -> None:
+        """Load the conversation from the storage
+
+        Warning: This will overwrite the current conversation.
+
+        Args:
+            conv_storage (StorageInterface): The conversation storage interface
+            message_storage (StorageInterface): The message storage interface
+        """
+        # Load conversation first
+        conversation: StorageConversation = conv_storage.load(
+            self._id, StorageConversation
+        )
+        if conversation is None:
+            return
+        message_ids = conversation._message_ids or []
+
+        # Load messages
+        message_list = message_storage.load_list(
+            [
+                MessageIdentifier.from_str_identifier(message_id)
+                for message_id in message_ids
+            ],
+            MessageStorageItem,
+        )
+        messages = [message.to_message() for message in message_list]
+        conversation.messages = messages
+        self._message_ids = message_ids
+        self.from_conversation(conversation)
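Putting the pieces together: a minimal sketch of round-tripping a conversation through the new storage interfaces (in-memory here; a database-backed `StorageInterface` would be wired in the same way, and `"chat_normal"` is a placeholder chat mode):

    conv_store = InMemoryStorage()
    msg_store = InMemoryStorage()

    conv = StorageConversation(
        "conv-123",
        chat_mode="chat_normal",
        conv_storage=conv_store,
        message_storage=msg_store,
    )
    conv.start_new_round()
    conv.add_user_message("hello")
    conv.add_ai_message("hi")
    conv.end_current_round()  # persists the messages, then the conversation

    # A fresh object with the same conv_uid rehydrates itself from storage.
    same_conv = StorageConversation(
        "conv-123", conv_storage=conv_store, message_storage=msg_store
    )
    assert len(same_conv.messages) == 2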
-def _conversation_to_dict(once: OnceConversation) -> dict:
+def _conversation_to_dict(once: OnceConversation) -> Dict:
     start_str: str = ""
     if hasattr(once, "start_date") and once.start_date:
         if isinstance(once.start_date, datetime):

@@ -303,6 +752,7 @@ def _conversation_to_dict(once: OnceConversation) -> dict:
         "param_value": once.param_value,
         "user_name": once.user_name,
         "sys_code": once.sys_code,
+        "summary": once.summary if once.summary else "",
     }
@@ -92,7 +92,7 @@ class BaseOutputParser(MapOperator[ModelOutput, Any], ABC):
                 f"""Model server error!code={resp_obj_ex["error_code"]}, errmsg is {resp_obj_ex["text"]}"""
             )

-    def __illegal_json_ends(self, s):
+    def _illegal_json_ends(self, s):
         temp_json = s
         illegal_json_ends_1 = [", }", ",}"]
         illegal_json_ends_2 = ", ]", ",]"
@@ -102,25 +102,25 @@ class BaseOutputParser(MapOperator[ModelOutput, Any], ABC):
             temp_json = temp_json.replace(illegal_json_end, " ]")
         return temp_json

-    def __extract_json(self, s):
+    def _extract_json(self, s):
         try:
             # Get the dual-mode analysis first and get the maximum result
-            temp_json_simple = self.__json_interception(s)
-            temp_json_array = self.__json_interception(s, True)
+            temp_json_simple = self._json_interception(s)
+            temp_json_array = self._json_interception(s, True)
             if len(temp_json_simple) > len(temp_json_array):
                 temp_json = temp_json_simple
             else:
                 temp_json = temp_json_array

             if not temp_json:
-                temp_json = self.__json_interception(s)
+                temp_json = self._json_interception(s)

-            temp_json = self.__illegal_json_ends(temp_json)
+            temp_json = self._illegal_json_ends(temp_json)
             return temp_json
         except Exception as e:
             raise ValueError("Failed to find a valid json in LLM response!" + temp_json)

-    def __json_interception(self, s, is_json_array: bool = False):
+    def _json_interception(self, s, is_json_array: bool = False):
         try:
             if is_json_array:
                 i = s.find("[")
@@ -176,7 +176,7 @@ class BaseOutputParser(MapOperator[ModelOutput, Any], ABC):
         cleaned_output = cleaned_output.strip()
         if not cleaned_output.startswith("{") or not cleaned_output.endswith("}"):
             logger.info("illegal json processing:\n" + cleaned_output)
-            cleaned_output = self.__extract_json(cleaned_output)
+            cleaned_output = self._extract_json(cleaned_output)

         if not cleaned_output or len(cleaned_output) <= 0:
             return model_out_text
@@ -188,7 +188,7 @@ class BaseOutputParser(MapOperator[ModelOutput, Any], ABC):
             .replace("\\", " ")
             .replace("\_", "_")
         )
-        cleaned_output = self.__illegal_json_ends(cleaned_output)
+        cleaned_output = self._illegal_json_ends(cleaned_output)
         return cleaned_output

     def parse_view_response(
@@ -208,20 +208,6 @@ class BaseOutputParser(MapOperator[ModelOutput, Any], ABC):
         """Instructions on how the LLM output should be formatted."""
         raise NotImplementedError

-    # @property
-    # def _type(self) -> str:
-    #     """Return the type key."""
-    #     raise NotImplementedError(
-    #         f"_type property is not implemented in class {self.__class__.__name__}."
-    #         " This is required for serialization."
-    #     )
-
-    def dict(self, **kwargs: Any) -> Dict:
-        """Return dictionary representation of output parser."""
-        output_parser_dict = super().dict()
-        output_parser_dict["_type"] = self._type
-        return output_parser_dict
-
     async def map(self, input_value: ModelOutput) -> Any:
         """Parse the output of an LLM call.
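For context on what the renamed helper repairs, here is a standalone sketch of the trailing-comma fix that _illegal_json_ends performs (a minimal re-statement of the logic above, not a verbatim copy of the file):

def illegal_json_ends(s: str) -> str:
    # LLM output often ends a JSON object or array with a trailing comma,
    # which json.loads rejects; rewrite ", }" / ",}" and ", ]" / ",]".
    temp_json = s
    for end in [", }", ",}"]:
        temp_json = temp_json.replace(end, " }")
    for end in [", ]", ",]"]:
        temp_json = temp_json.replace(end, " ]")
    return temp_json

print(illegal_json_ends('{"columns": ["a", "b",]}'))  # {"columns": ["a", "b" ]}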
@@ -1,19 +1,34 @@
+from __future__ import annotations
 from abc import ABC, abstractmethod
 from typing import Type, Dict


 class Serializable(ABC):
+    serializer: "Serializer" = None

     @abstractmethod
+    def to_dict(self) -> Dict:
+        """Convert the object's state to a dictionary."""
+
     def serialize(self) -> bytes:
         """Convert the object into bytes for storage or transmission.

         Returns:
             bytes: The byte array after serialization
         """
+        if self.serializer is None:
+            raise ValueError(
+                "Serializer is not set. Please set the serializer before serialization."
+            )
+        return self.serializer.serialize(self)

-    @abstractmethod
-    def to_dict(self) -> Dict:
-        """Convert the object's state to a dictionary."""
+    def set_serializer(self, serializer: "Serializer") -> None:
+        """Set the serializer for current serializable object.
+
+        Args:
+            serializer (Serializer): The serializer to set
+        """
+        self.serializer = serializer


 class Serializer(ABC):
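A minimal usage sketch of the new Serializable contract; EchoData is a hypothetical class used only to illustrate the to_dict/set_serializer/serialize flow:

from dbgpt.util.serialization.json_serialization import JsonSerializer

class EchoData(Serializable):  # hypothetical example class, not part of this commit
    def __init__(self, text: str):
        self.text = text

    def to_dict(self) -> Dict:
        return {"text": self.text}

data = EchoData("hello")
data.set_serializer(JsonSerializer())  # must be set, or serialize() raises ValueError
payload: bytes = data.serialize()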
dbgpt/core/interface/storage.py (new file, 409 lines)
@@ -0,0 +1,409 @@
from typing import Generic, TypeVar, Type, Optional, Dict, Any, List
from abc import ABC, abstractmethod

from dbgpt.core.interface.serialization import Serializable, Serializer
from dbgpt.util.serialization.json_serialization import JsonSerializer
from dbgpt.util.annotations import PublicAPI
from dbgpt.util.pagination_utils import PaginationResult


@PublicAPI(stability="beta")
class ResourceIdentifier(Serializable, ABC):
    """The resource identifier interface for resource identifiers."""

    @property
    @abstractmethod
    def str_identifier(self) -> str:
        """Get the string identifier of the resource.

        The string identifier is used to uniquely identify the resource.

        Returns:
            str: The string identifier of the resource
        """

    def __hash__(self) -> int:
        """Return the hash value of the key."""
        return hash(self.str_identifier)

    def __eq__(self, other: Any) -> bool:
        """Check equality with another key."""
        if not isinstance(other, ResourceIdentifier):
            return False
        return self.str_identifier == other.str_identifier


@PublicAPI(stability="beta")
class StorageItem(Serializable, ABC):
    """The storage item interface for storage items."""

    @property
    @abstractmethod
    def identifier(self) -> ResourceIdentifier:
        """Get the resource identifier of the storage item.

        Returns:
            ResourceIdentifier: The resource identifier of the storage item
        """

    @abstractmethod
    def merge(self, other: "StorageItem") -> None:
        """Merge the other storage item into the current storage item.

        Args:
            other (StorageItem): The other storage item
        """


T = TypeVar("T", bound=StorageItem)
TDataRepresentation = TypeVar("TDataRepresentation")


class StorageItemAdapter(Generic[T, TDataRepresentation]):
    """The storage item adapter for converting storage items to and from the storage format.

    Sometimes the storage item is not the same as the storage format,
    so we need to convert the storage item to the storage format and vice versa.

    In database storage, the storage format is the database model, but the
    StorageItem is the user-defined object.
    """

    @abstractmethod
    def to_storage_format(self, item: T) -> TDataRepresentation:
        """Convert the storage item to the storage format.

        Args:
            item (T): The storage item

        Returns:
            TDataRepresentation: The data in the storage format
        """

    @abstractmethod
    def from_storage_format(self, data: TDataRepresentation) -> T:
        """Convert the storage format to the storage item.

        Args:
            data (TDataRepresentation): The data in the storage format

        Returns:
            T: The storage item
        """

    @abstractmethod
    def get_query_for_identifier(
        self,
        storage_format: Type[TDataRepresentation],
        resource_id: ResourceIdentifier,
        **kwargs,
    ) -> Any:
        """Get the query for the resource identifier.

        Args:
            storage_format (Type[TDataRepresentation]): The storage format
            resource_id (ResourceIdentifier): The resource identifier
            kwargs: The additional arguments

        Returns:
            Any: The query for the resource identifier
        """


class DefaultStorageItemAdapter(StorageItemAdapter[T, T]):
    """The default storage item adapter.

    The storage item is the same as the storage format, so no conversion is required.
    """

    def to_storage_format(self, item: T) -> T:
        return item

    def from_storage_format(self, data: T) -> T:
        return data

    def get_query_for_identifier(
        self, storage_format: Type[T], resource_id: ResourceIdentifier, **kwargs
    ) -> bool:
        return True


@PublicAPI(stability="beta")
class StorageError(Exception):
    """The base exception class for storage errors."""

    def __init__(self, message: str):
        super().__init__(message)


@PublicAPI(stability="beta")
class QuerySpec:
    """The query specification for querying data from the storage.

    Attributes:
        conditions (Dict[str, Any]): The conditions for querying data
        limit (int): The maximum number of items to return
        offset (int): The offset of the items to return
    """

    def __init__(
        self, conditions: Dict[str, Any], limit: Optional[int] = None, offset: int = 0
    ) -> None:
        self.conditions = conditions
        self.limit = limit
        self.offset = offset


@PublicAPI(stability="beta")
class StorageInterface(Generic[T, TDataRepresentation], ABC):
    """The storage interface for storing and loading data."""

    def __init__(
        self,
        serializer: Optional[Serializer] = None,
        adapter: Optional[StorageItemAdapter[T, TDataRepresentation]] = None,
    ):
        self._serializer = serializer or JsonSerializer()
        self._storage_item_adapter = adapter or DefaultStorageItemAdapter()

    @property
    def serializer(self) -> Serializer:
        """Get the serializer of the storage.

        Returns:
            Serializer: The serializer of the storage
        """
        return self._serializer

    @property
    def adapter(self) -> StorageItemAdapter[T, TDataRepresentation]:
        """Get the adapter of the storage.

        Returns:
            StorageItemAdapter[T, TDataRepresentation]: The adapter of the storage
        """
        return self._storage_item_adapter

    @abstractmethod
    def save(self, data: T) -> None:
        """Save the data to the storage.

        Args:
            data (T): The data to save

        Raises:
            StorageError: If the data already exists in the storage or data is None
        """

    @abstractmethod
    def update(self, data: T) -> None:
        """Update the data in the storage.

        Args:
            data (T): The data to save

        Raises:
            StorageError: If data is None
        """

    @abstractmethod
    def save_or_update(self, data: T) -> None:
        """Save or update the data in the storage.

        Args:
            data (T): The data to save

        Raises:
            StorageError: If data is None
        """

    def save_list(self, data: List[T]) -> None:
        """Save the data list to the storage.

        Args:
            data (List[T]): The data to save

        Raises:
            StorageError: If the data already exists in the storage or data is None
        """
        for d in data:
            self.save(d)

    def save_or_update_list(self, data: List[T]) -> None:
        """Save or update the data list in the storage.

        Args:
            data (List[T]): The data to save
        """
        for d in data:
            self.save_or_update(d)

    @abstractmethod
    def load(self, resource_id: ResourceIdentifier, cls: Type[T]) -> Optional[T]:
        """Load the data from the storage.

        None will be returned if the data does not exist in the storage.

        Loading data by resource_id is faster than querying with conditions,
        so we suggest using load if possible.

        Args:
            resource_id (ResourceIdentifier): The resource identifier of the data
            cls (Type[T]): The type of the data

        Returns:
            Optional[T]: The loaded data
        """

    def load_list(self, resource_id: List[ResourceIdentifier], cls: Type[T]) -> List[T]:
        """Load the data list from the storage.

        Items that do not exist in the storage are skipped.

        Loading data by resource_id is faster than querying with conditions,
        so we suggest using load if possible.

        Args:
            resource_id (List[ResourceIdentifier]): The resource identifiers of the data
            cls (Type[T]): The type of the data

        Returns:
            List[T]: The loaded data
        """
        result = []
        for r in resource_id:
            item = self.load(r, cls)
            if item is not None:
                result.append(item)
        return result

    @abstractmethod
    def delete(self, resource_id: ResourceIdentifier) -> None:
        """Delete the data from the storage.

        Args:
            resource_id (ResourceIdentifier): The resource identifier of the data
        """

    @abstractmethod
    def query(self, spec: QuerySpec, cls: Type[T]) -> List[T]:
        """Query data from the storage.

        Loading by resource_id is faster than querying with conditions,
        so please use load if possible.

        Args:
            spec (QuerySpec): The query specification
            cls (Type[T]): The type of the data

        Returns:
            List[T]: The queried data
        """

    @abstractmethod
    def count(self, spec: QuerySpec, cls: Type[T]) -> int:
        """Count the number of items in the storage.

        Args:
            spec (QuerySpec): The query specification
            cls (Type[T]): The type of the data

        Returns:
            int: The number of items
        """

    def paginate_query(
        self, page: int, page_size: int, cls: Type[T], spec: Optional[QuerySpec] = None
    ) -> PaginationResult[T]:
        """Paginate the query result.

        Args:
            page (int): The page number
            page_size (int): The number of items per page
            cls (Type[T]): The type of the data
            spec (Optional[QuerySpec], optional): The query specification. Defaults to None.

        Returns:
            PaginationResult[T]: The pagination result
        """
        if spec is None:
            spec = QuerySpec(conditions={})
        spec.limit = page_size
        spec.offset = (page - 1) * page_size
        items = self.query(spec, cls)
        total = self.count(spec, cls)
        return PaginationResult(
            items=items,
            total_count=total,
            total_pages=(total + page_size - 1) // page_size,
            page=page,
            page_size=page_size,
        )


@PublicAPI(stability="alpha")
class InMemoryStorage(StorageInterface[T, T]):
    """The in-memory storage for storing and loading data."""

    def __init__(
        self,
        serializer: Optional[Serializer] = None,
    ):
        super().__init__(serializer)
        self._data = {}  # Key: ResourceIdentifier, Value: Serialized data

    def save(self, data: T) -> None:
        if not data:
            raise StorageError("Data cannot be None")
        if not data.serializer:
            data.set_serializer(self.serializer)

        if data.identifier.str_identifier in self._data:
            raise StorageError(
                f"Data with identifier {data.identifier.str_identifier} already exists"
            )
        self._data[data.identifier.str_identifier] = data.serialize()

    def update(self, data: T) -> None:
        if not data:
            raise StorageError("Data cannot be None")
        if not data.serializer:
            data.set_serializer(self.serializer)
        self._data[data.identifier.str_identifier] = data.serialize()

    def save_or_update(self, data: T) -> None:
        self.update(data)

    def load(self, resource_id: ResourceIdentifier, cls: Type[T]) -> Optional[T]:
        serialized_data = self._data.get(resource_id.str_identifier)
        if serialized_data is None:
            return None
        return self.serializer.deserialize(serialized_data, cls)

    def delete(self, resource_id: ResourceIdentifier) -> None:
        if resource_id.str_identifier in self._data:
            del self._data[resource_id.str_identifier]

    def query(self, spec: QuerySpec, cls: Type[T]) -> List[T]:
        result = []
        for serialized_data in self._data.values():
            data = self._serializer.deserialize(serialized_data, cls)
            if all(
                getattr(data, key) == value for key, value in spec.conditions.items()
            ):
                result.append(data)

        # Apply limit and offset
        if spec.limit is not None:
            result = result[spec.offset : spec.offset + spec.limit]
        else:
            result = result[spec.offset :]
        return result

    def count(self, spec: QuerySpec, cls: Type[T]) -> int:
        count = 0
        for serialized_data in self._data.values():
            data = self._serializer.deserialize(serialized_data, cls)
            if all(
                getattr(data, key) == value for key, value in spec.conditions.items()
            ):
                count += 1
        return count
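To make the adapter contract above concrete, here is a minimal sketch of a non-default adapter. ChatItem and ChatModel (a SQLAlchemy model with conv_uid and summary columns) are hypothetical names, not part of this commit:

class ChatItemAdapter(StorageItemAdapter["ChatItem", "ChatModel"]):
    def to_storage_format(self, item: "ChatItem") -> "ChatModel":
        # User-defined object -> database model
        return ChatModel(conv_uid=item.identifier.str_identifier, summary=item.summary)

    def from_storage_format(self, data: "ChatModel") -> "ChatItem":
        # Database model -> user-defined object
        return ChatItem(conv_uid=data.conv_uid, summary=data.summary)

    def get_query_for_identifier(self, storage_format, resource_id, **kwargs):
        # Assumes a SQLAlchemy session is passed in via kwargs.
        session = kwargs["session"]
        return session.query(storage_format).filter(
            storage_format.conv_uid == resource_id.str_identifier
        )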
dbgpt/core/interface/tests/__init__.py (new file, empty)

dbgpt/core/interface/tests/conftest.py (new file, 14 lines)
@@ -0,0 +1,14 @@
import pytest

from dbgpt.core.interface.storage import InMemoryStorage
from dbgpt.util.serialization.json_serialization import JsonSerializer


@pytest.fixture
def serializer():
    return JsonSerializer()


@pytest.fixture
def in_memory_storage(serializer):
    return InMemoryStorage(serializer)
dbgpt/core/interface/tests/test_message.py (new file, 307 lines)
@@ -0,0 +1,307 @@
import pytest

from dbgpt.core.interface.tests.conftest import in_memory_storage
from dbgpt.core.interface.message import *


@pytest.fixture
def basic_conversation():
    return OnceConversation(chat_mode="chat_normal", user_name="user1", sys_code="sys1")


@pytest.fixture
def human_message():
    return HumanMessage(content="Hello")


@pytest.fixture
def ai_message():
    return AIMessage(content="Hi there")


@pytest.fixture
def system_message():
    return SystemMessage(content="System update")


@pytest.fixture
def view_message():
    return ViewMessage(content="View this")


@pytest.fixture
def conversation_identifier():
    return ConversationIdentifier("conv1")


@pytest.fixture
def message_identifier():
    return MessageIdentifier("conv1", 1)


@pytest.fixture
def message_storage_item():
    message = HumanMessage(content="Hello", index=1)
    message_detail = message.to_dict()
    return MessageStorageItem("conv1", 1, message_detail)


@pytest.fixture
def storage_conversation():
    return StorageConversation("conv1", chat_mode="chat_normal", user_name="user1")


@pytest.fixture
def conversation_with_messages():
    conv = OnceConversation(chat_mode="chat_normal", user_name="user1")
    conv.start_new_round()
    conv.add_user_message("Hello")
    conv.add_ai_message("Hi")
    conv.end_current_round()

    conv.start_new_round()
    conv.add_user_message("How are you?")
    conv.add_ai_message("I'm good, thanks")
    conv.end_current_round()

    return conv


def test_init(basic_conversation):
    assert basic_conversation.chat_mode == "chat_normal"
    assert basic_conversation.user_name == "user1"
    assert basic_conversation.sys_code == "sys1"
    assert basic_conversation.messages == []
    assert basic_conversation.start_date == ""
    assert basic_conversation.chat_order == 0
    assert basic_conversation.model_name == ""
    assert basic_conversation.param_type == ""
    assert basic_conversation.param_value == ""
    assert basic_conversation.cost == 0
    assert basic_conversation.tokens == 0
    assert basic_conversation._message_index == 0


def test_add_user_message(basic_conversation, human_message):
    basic_conversation.add_user_message(human_message.content)
    assert len(basic_conversation.messages) == 1
    assert isinstance(basic_conversation.messages[0], HumanMessage)


def test_add_ai_message(basic_conversation, ai_message):
    basic_conversation.add_ai_message(ai_message.content)
    assert len(basic_conversation.messages) == 1
    assert isinstance(basic_conversation.messages[0], AIMessage)


def test_add_system_message(basic_conversation, system_message):
    basic_conversation.add_system_message(system_message.content)
    assert len(basic_conversation.messages) == 1
    assert isinstance(basic_conversation.messages[0], SystemMessage)


def test_add_view_message(basic_conversation, view_message):
    basic_conversation.add_view_message(view_message.content)
    assert len(basic_conversation.messages) == 1
    assert isinstance(basic_conversation.messages[0], ViewMessage)


def test_set_start_time(basic_conversation):
    now = datetime.now()
    basic_conversation.set_start_time(now)
    assert basic_conversation.start_date == now.strftime("%Y-%m-%d %H:%M:%S")


def test_clear_messages(basic_conversation, human_message):
    basic_conversation.add_user_message(human_message.content)
    basic_conversation.clear()
    assert len(basic_conversation.messages) == 0


def test_get_latest_user_message(basic_conversation, human_message):
    basic_conversation.add_user_message(human_message.content)
    latest_message = basic_conversation.get_latest_user_message()
    assert latest_message == human_message


def test_get_system_messages(basic_conversation, system_message):
    basic_conversation.add_system_message(system_message.content)
    system_messages = basic_conversation.get_system_messages()
    assert len(system_messages) == 1
    assert system_messages[0] == system_message


def test_from_conversation(basic_conversation):
    new_conversation = OnceConversation(chat_mode="chat_advanced", user_name="user2")
    basic_conversation.from_conversation(new_conversation)
    assert basic_conversation.chat_mode == "chat_advanced"
    assert basic_conversation.user_name == "user2"


def test_get_messages_by_round(conversation_with_messages):
    # Test first round
    round1_messages = conversation_with_messages.get_messages_by_round(1)
    assert len(round1_messages) == 2
    assert round1_messages[0].content == "Hello"
    assert round1_messages[1].content == "Hi"

    # Test non-existing round
    no_messages = conversation_with_messages.get_messages_by_round(3)
    assert len(no_messages) == 0


def test_get_latest_round(conversation_with_messages):
    latest_round_messages = conversation_with_messages.get_latest_round()
    assert len(latest_round_messages) == 2
    assert latest_round_messages[0].content == "How are you?"
    assert latest_round_messages[1].content == "I'm good, thanks"


def test_get_messages_with_round(conversation_with_messages):
    # Test last round
    last_round_messages = conversation_with_messages.get_messages_with_round(1)
    assert len(last_round_messages) == 2
    assert last_round_messages[0].content == "How are you?"
    assert last_round_messages[1].content == "I'm good, thanks"

    # Test last two rounds
    last_two_rounds_messages = conversation_with_messages.get_messages_with_round(2)
    assert len(last_two_rounds_messages) == 4
    assert last_two_rounds_messages[0].content == "Hello"
    assert last_two_rounds_messages[1].content == "Hi"


def test_get_model_messages(conversation_with_messages):
    model_messages = conversation_with_messages.get_model_messages()
    assert len(model_messages) == 4
    assert all(isinstance(msg, ModelMessage) for msg in model_messages)
    assert model_messages[0].content == "Hello"
    assert model_messages[1].content == "Hi"
    assert model_messages[2].content == "How are you?"
    assert model_messages[3].content == "I'm good, thanks"


def test_conversation_identifier(conversation_identifier):
    assert conversation_identifier.conv_uid == "conv1"
    assert conversation_identifier.identifier_type == "conversation"
    assert conversation_identifier.str_identifier == "conversation:conv1"
    assert conversation_identifier.to_dict() == {
        "conv_uid": "conv1",
        "identifier_type": "conversation",
    }


def test_message_identifier(message_identifier):
    assert message_identifier.conv_uid == "conv1"
    assert message_identifier.index == 1
    assert message_identifier.identifier_type == "message"
    assert message_identifier.str_identifier == "message___conv1___1"
    assert message_identifier.to_dict() == {
        "conv_uid": "conv1",
        "index": 1,
        "identifier_type": "message",
    }


def test_message_storage_item(message_storage_item):
    assert message_storage_item.conv_uid == "conv1"
    assert message_storage_item.index == 1
    assert message_storage_item.message_detail == {
        "type": "human",
        "data": {
            "content": "Hello",
            "index": 1,
            "round_index": 0,
            "additional_kwargs": {},
            "example": False,
        },
        "index": 1,
        "round_index": 0,
    }

    assert isinstance(message_storage_item.identifier, MessageIdentifier)
    assert message_storage_item.to_dict() == {
        "conv_uid": "conv1",
        "index": 1,
        "message_detail": {
            "type": "human",
            "index": 1,
            "data": {
                "content": "Hello",
                "index": 1,
                "round_index": 0,
                "additional_kwargs": {},
                "example": False,
            },
            "round_index": 0,
        },
    }

    assert isinstance(message_storage_item.to_message(), BaseMessage)


def test_storage_conversation_init(storage_conversation):
    assert storage_conversation.conv_uid == "conv1"
    assert storage_conversation.chat_mode == "chat_normal"
    assert storage_conversation.user_name == "user1"


def test_storage_conversation_add_user_message(storage_conversation):
    storage_conversation.add_user_message("Hi")
    assert len(storage_conversation.messages) == 1
    assert isinstance(storage_conversation.messages[0], HumanMessage)


def test_storage_conversation_add_ai_message(storage_conversation):
    storage_conversation.add_ai_message("Hello")
    assert len(storage_conversation.messages) == 1
    assert isinstance(storage_conversation.messages[0], AIMessage)


def test_save_to_storage(storage_conversation, in_memory_storage):
    # Set storage
    storage_conversation.conv_storage = in_memory_storage
    storage_conversation.message_storage = in_memory_storage

    # Add messages
    storage_conversation.add_user_message("User message")
    storage_conversation.add_ai_message("AI response")

    # Save to storage
    storage_conversation.save_to_storage()

    # Create a new StorageConversation instance to load the data
    saved_conversation = StorageConversation(
        storage_conversation.conv_uid,
        conv_storage=in_memory_storage,
        message_storage=in_memory_storage,
    )

    assert saved_conversation.conv_uid == storage_conversation.conv_uid
    assert len(saved_conversation.messages) == 2
    assert isinstance(saved_conversation.messages[0], HumanMessage)
    assert isinstance(saved_conversation.messages[1], AIMessage)


def test_load_from_storage(storage_conversation, in_memory_storage):
    # Set storage
    storage_conversation.conv_storage = in_memory_storage
    storage_conversation.message_storage = in_memory_storage

    # Add messages and save to storage
    storage_conversation.add_user_message("User message")
    storage_conversation.add_ai_message("AI response")
    storage_conversation.save_to_storage()

    # Create a new StorageConversation instance to load the data
    new_conversation = StorageConversation(
        "conv1", conv_storage=in_memory_storage, message_storage=in_memory_storage
    )

    # Check if the data is loaded correctly
    assert new_conversation.conv_uid == storage_conversation.conv_uid
    assert len(new_conversation.messages) == 2
    assert new_conversation.messages[0].content == "User message"
    assert new_conversation.messages[1].content == "AI response"
    assert isinstance(new_conversation.messages[0], HumanMessage)
    assert isinstance(new_conversation.messages[1], AIMessage)
dbgpt/core/interface/tests/test_storage.py (new file, 129 lines)
@@ -0,0 +1,129 @@
import pytest
from typing import Dict, Type, Union
from dbgpt.core.interface.storage import (
    ResourceIdentifier,
    StorageError,
    QuerySpec,
    InMemoryStorage,
    StorageItem,
)
from dbgpt.util.serialization.json_serialization import JsonSerializer


class MockResourceIdentifier(ResourceIdentifier):
    def __init__(self, identifier: str):
        self._identifier = identifier

    @property
    def str_identifier(self) -> str:
        return self._identifier

    def to_dict(self) -> Dict:
        return {"identifier": self._identifier}


class MockStorageItem(StorageItem):
    def merge(self, other: "StorageItem") -> None:
        if not isinstance(other, MockStorageItem):
            raise ValueError("other must be a MockStorageItem")
        self.data = other.data

    def __init__(self, identifier: Union[str, MockResourceIdentifier], data):
        self._identifier_str = (
            identifier if isinstance(identifier, str) else identifier.str_identifier
        )
        self.data = data

    def to_dict(self) -> Dict:
        return {"identifier": self._identifier_str, "data": self.data}

    @property
    def identifier(self) -> ResourceIdentifier:
        return MockResourceIdentifier(self._identifier_str)


@pytest.fixture
def serializer():
    return JsonSerializer()


@pytest.fixture
def in_memory_storage(serializer):
    return InMemoryStorage(serializer)


def test_save_and_load(in_memory_storage):
    resource_id = MockResourceIdentifier("1")
    item = MockStorageItem(resource_id, "test_data")

    in_memory_storage.save(item)

    loaded_item = in_memory_storage.load(resource_id, MockStorageItem)
    assert loaded_item.data == "test_data"


def test_duplicate_save(in_memory_storage):
    item = MockStorageItem("1", "test_data")

    in_memory_storage.save(item)

    # Should raise StorageError when saving the same data
    with pytest.raises(StorageError):
        in_memory_storage.save(item)


def test_delete(in_memory_storage):
    resource_id = MockResourceIdentifier("1")
    item = MockStorageItem(resource_id, "test_data")

    in_memory_storage.save(item)
    in_memory_storage.delete(resource_id)
    # Storage should not contain the data after deletion
    assert in_memory_storage.load(resource_id, MockStorageItem) is None


def test_query(in_memory_storage):
    resource_id1 = MockResourceIdentifier("1")
    item1 = MockStorageItem(resource_id1, "test_data1")

    resource_id2 = MockResourceIdentifier("2")
    item2 = MockStorageItem(resource_id2, "test_data2")

    in_memory_storage.save(item1)
    in_memory_storage.save(item2)

    query_spec = QuerySpec(conditions={"data": "test_data1"})
    results = in_memory_storage.query(query_spec, MockStorageItem)
    assert len(results) == 1
    assert results[0].data == "test_data1"


def test_count(in_memory_storage):
    item1 = MockStorageItem("1", "test_data1")
    item2 = MockStorageItem("2", "test_data2")

    in_memory_storage.save(item1)
    in_memory_storage.save(item2)

    query_spec = QuerySpec(conditions={})
    count = in_memory_storage.count(query_spec, MockStorageItem)
    assert count == 2


def test_paginate_query(in_memory_storage):
    for i in range(10):
        resource_id = MockResourceIdentifier(str(i))
        item = MockStorageItem(resource_id, f"test_data{i}")
        in_memory_storage.save(item)

    page_size = 3
    query_spec = QuerySpec(conditions={})
    page_result = in_memory_storage.paginate_query(
        2, page_size, MockStorageItem, query_spec
    )

    assert len(page_result.items) == page_size
    assert page_result.total_count == 10
    assert page_result.total_pages == 4
    assert page_result.page == 2
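The expected total_pages above follows from the ceiling division used by paginate_query; a quick check in plain Python:

total, page_size = 10, 3
# (total + page_size - 1) // page_size is integer ceiling division
assert (total + page_size - 1) // page_size == 4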
@@ -91,6 +91,10 @@ class BaseConnect(ABC):
         """Get column fields about specified table."""
         pass

+    def get_simple_fields(self, table_name):
+        """Get column fields about specified table."""
+        return self.get_fields(table_name)
+
     def get_show_create_table(self, table_name):
         """Get the creation table sql about specified table."""
         pass
@@ -1,16 +1,10 @@
 from sqlalchemy import Column, Integer, String, Index, Text, text
 from sqlalchemy import UniqueConstraint

-from dbgpt.storage.metadata import BaseDao
-from dbgpt.storage.metadata.meta_data import (
-    Base,
-    engine,
-    session,
-    META_DATA_DATABASE,
-)
+from dbgpt.storage.metadata import BaseDao, Model


-class ConnectConfigEntity(Base):
+class ConnectConfigEntity(Model):
     """db connect config entity"""

     __tablename__ = "connect_config"
@@ -38,17 +32,9 @@ class ConnectConfigEntity(Model):
 class ConnectConfigDao(BaseDao[ConnectConfigEntity]):
     """db connect config dao"""

-    def __init__(self):
-        super().__init__(
-            database=META_DATA_DATABASE,
-            orm_base=Base,
-            db_engine=engine,
-            session=session,
-        )
-
     def update(self, entity: ConnectConfigEntity):
         """update db connect info"""
-        session = self.get_session()
+        session = self.get_raw_session()
         try:
             updated = session.merge(entity)
             session.commit()
@@ -58,7 +44,7 @@ class ConnectConfigDao(BaseDao[ConnectConfigEntity]):

     def delete(self, db_name: int):
         """delete db connect info"""
-        session = self.get_session()
+        session = self.get_raw_session()
         if db_name is None:
             raise Exception("db_name is None")

@@ -70,7 +56,7 @@ class ConnectConfigDao(BaseDao[ConnectConfigEntity]):

     def get_by_names(self, db_name: str) -> ConnectConfigEntity:
         """get db connect info by name"""
-        session = self.get_session()
+        session = self.get_raw_session()
         db_connect = session.query(ConnectConfigEntity)
         db_connect = db_connect.filter(ConnectConfigEntity.db_name == db_name)
         result = db_connect.first()
@@ -99,7 +85,7 @@ class ConnectConfigDao(BaseDao[ConnectConfigEntity]):
             comment: comment
         """
         try:
-            session = self.get_session()
+            session = self.get_raw_session()

             from sqlalchemy import text

@@ -144,7 +130,7 @@ class ConnectConfigDao(BaseDao[ConnectConfigEntity]):
         old_db_conf = self.get_db_config(db_name)
         if old_db_conf:
             try:
-                session = self.get_session()
+                session = self.get_raw_session()
                 if not db_path:
                     update_statement = text(
                         f"UPDATE connect_config set db_type='{db_type}', db_host='{db_host}', db_port={db_port}, db_user='{db_user}', db_pwd='{db_pwd}', comment='{comment}' where db_name='{db_name}'"
@@ -164,7 +150,7 @@ class ConnectConfigDao(BaseDao[ConnectConfigEntity]):
     def add_file_db(self, db_name, db_type, db_path: str, comment: str = ""):
         """add file db connect info"""
         try:
-            session = self.get_session()
+            session = self.get_raw_session()
             insert_statement = text(
                 """
                 INSERT INTO connect_config(
@@ -194,7 +180,7 @@ class ConnectConfigDao(BaseDao[ConnectConfigEntity]):

     def get_db_config(self, db_name):
         """get db config by name"""
-        session = self.get_session()
+        session = self.get_raw_session()
         if db_name:
             select_statement = text(
                 """
@@ -221,7 +207,7 @@ class ConnectConfigDao(BaseDao[ConnectConfigEntity]):

     def get_db_list(self):
         """get db list"""
-        session = self.get_session()
+        session = self.get_raw_session()
         result = session.execute(text("SELECT * FROM connect_config"))

         fields = [field[0] for field in result.cursor.description]
@@ -235,7 +221,7 @@ class ConnectConfigDao(BaseDao[ConnectConfigEntity]):

     def delete_db(self, db_name):
         """delete db connect info"""
-        session = self.get_session()
+        session = self.get_raw_session()
         delete_statement = text("""DELETE FROM connect_config where db_name=:db_name""")
         params = {"db_name": db_name}
         session.execute(delete_statement, params)
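Note that the UPDATE statement above interpolates values into the SQL with an f-string, while delete_db already uses the safer bound-parameter style. A sketch of the UPDATE rewritten that way (the function name update_db_info is hypothetical, not part of this commit):

from sqlalchemy import text

def update_db_info(session, db_name, db_type, db_host, db_port, db_user, db_pwd, comment=""):
    # Bound parameters instead of f-string interpolation, mirroring delete_db.
    update_statement = text(
        "UPDATE connect_config SET db_type=:db_type, db_host=:db_host, "
        "db_port=:db_port, db_user=:db_user, db_pwd=:db_pwd, comment=:comment "
        "WHERE db_name=:db_name"
    )
    session.execute(
        update_statement,
        {
            "db_type": db_type,
            "db_host": db_host,
            "db_port": db_port,
            "db_user": db_user,
            "db_pwd": db_pwd,
            "comment": comment,
            "db_name": db_name,
        },
    )
    session.commit()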
@@ -270,7 +270,12 @@ class RDBMSDatabase(BaseConnect):
         """Format the error message"""
         return f"Error: {e}"

-    def __write(self, write_sql):
+    def _write(self, write_sql: str):
+        """Run a SQL write command and return the results as a list of tuples.
+
+        Args:
+            write_sql (str): SQL write command to run
+        """
         print(f"Write[{write_sql}]")
         db_cache = self._engine.url.database
         result = self.session.execute(text(write_sql))
@@ -280,16 +285,12 @@ class RDBMSDatabase(BaseConnect):
         print(f"SQL[{write_sql}], result:{result.rowcount}")
         return result.rowcount

-    def __query(self, query, fetch: str = "all"):
-        """
-        only for query
-        Args:
-            session:
-            query:
-            fetch:
-        Returns:
-
-        """
+    def _query(self, query: str, fetch: str = "all"):
+        """Run a SQL query and return the results as a list of tuples.
+
+        Args:
+            query (str): SQL query to run
+            fetch (str): fetch type
+        """
         print(f"Query[{query}]")
         if not query:
@@ -308,6 +309,10 @@ class RDBMSDatabase(BaseConnect):
         result.insert(0, field_names)
         return result

+    def query_table_schema(self, table_name):
+        sql = f"select * from {table_name} limit 1"
+        return self._query(sql)
+
     def query_ex(self, query, fetch: str = "all"):
         """
         only for query
@@ -325,7 +330,7 @@ class RDBMSDatabase(BaseConnect):
         if fetch == "all":
             result = cursor.fetchall()
         elif fetch == "one":
-            result = cursor.fetchone()[0]  # type: ignore
+            result = cursor.fetchone()  # type: ignore
         else:
             raise ValueError("Fetch parameter must be either 'one' or 'all'")
         field_names = list(i[0:] for i in cursor.keys())
@@ -342,12 +347,12 @@ class RDBMSDatabase(BaseConnect):
         parsed, ttype, sql_type, table_name = self.__sql_parse(command)
         if ttype == sqlparse.tokens.DML:
             if sql_type == "SELECT":
-                return self.__query(command, fetch)
+                return self._query(command, fetch)
             else:
-                self.__write(command)
+                self._write(command)
                 select_sql = self.convert_sql_write_to_select(command)
                 print(f"write result query:{select_sql}")
-                return self.__query(select_sql)
+                return self._query(select_sql)

         else:
             print(f"DDL execution determines whether to enable through configuration ")
@@ -360,10 +365,11 @@ class RDBMSDatabase(BaseConnect):
             result.insert(0, field_names)
             print("DDL Result:" + str(result))
             if not result:
-                return self.__query(f"SHOW COLUMNS FROM {table_name}")
+                # return self._query(f"SHOW COLUMNS FROM {table_name}")
+                return self.get_simple_fields(table_name)
             return result
         else:
-            return self.__query(f"SHOW COLUMNS FROM {table_name}")
+            return self.get_simple_fields(table_name)

     def run_to_df(self, command: str, fetch: str = "all"):
         result_lst = self.run(command, fetch)
@@ -451,13 +457,23 @@ class RDBMSDatabase(BaseConnect):
         sql = sql.strip()
         parsed = sqlparse.parse(sql)[0]
         sql_type = parsed.get_type()
-        table_name = parsed.get_name()
+        if sql_type == "CREATE":
+            table_name = self._extract_table_name_from_ddl(parsed)
+        else:
+            table_name = parsed.get_name()
+
         first_token = parsed.token_first(skip_ws=True, skip_cm=False)
         ttype = first_token.ttype
         print(f"SQL:{sql}, ttype:{ttype}, sql_type:{sql_type}, table:{table_name}")
         return parsed, ttype, sql_type, table_name

+    def _extract_table_name_from_ddl(self, parsed):
+        """Extract table name from CREATE TABLE statement."""
+        for token in parsed.tokens:
+            if token.ttype is None and isinstance(token, sqlparse.sql.Identifier):
+                return token.get_real_name()
+        return None
+
     def get_indexes(self, table_name):
         """Get table indexes about specified table."""
         session = self._db_sessions()
@@ -485,6 +501,10 @@ class RDBMSDatabase(BaseConnect):
         fields = cursor.fetchall()
         return [(field[0], field[1], field[2], field[3], field[4]) for field in fields]

+    def get_simple_fields(self, table_name):
+        """Get column fields about specified table."""
+        return self._query(f"SHOW COLUMNS FROM {table_name}")
+
     def get_charset(self):
         """Get character_set."""
         session = self._db_sessions()
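A small sketch of why the DDL branch is needed: for CREATE statements, sqlparse's Statement.get_name() does not reliably yield the table name, so the helper walks the token list instead. This is a standalone illustration (the expected output is an assumption based on the commit's own tests, not a guarantee of sqlparse behavior for every statement shape):

import sqlparse

def extract_table_name_from_ddl(parsed):
    # Mirrors the _extract_table_name_from_ddl helper above.
    for token in parsed.tokens:
        if token.ttype is None and isinstance(token, sqlparse.sql.Identifier):
            return token.get_real_name()
    return None

parsed = sqlparse.parse("CREATE TABLE test (id INTEGER PRIMARY KEY);")[0]
print(parsed.get_type())                     # "CREATE"
print(extract_table_name_from_ddl(parsed))   # "test" is expected here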
@@ -56,6 +56,10 @@ class SQLiteConnect(RDBMSDatabase):
         print(fields)
         return [(field[1], field[2], field[3], field[4], field[5]) for field in fields]

+    def get_simple_fields(self, table_name):
+        """Get column fields about specified table."""
+        return self.get_fields(table_name)
+
     def get_users(self):
         return []

@@ -88,8 +92,9 @@ class SQLiteConnect(RDBMSDatabase):
         self._metadata.reflect(bind=self._engine)
         return self._all_tables

-    def _write(self, session, write_sql):
+    def _write(self, write_sql):
         print(f"Write[{write_sql}]")
+        session = self.session
         result = session.execute(text(write_sql))
         session.commit()
         # TODO Subsequent optimization of dynamically specified database submission loss target problem
@@ -25,41 +25,41 @@ def test_get_table_info(db):


 def test_get_table_info_with_table(db):
-    db.run(db.session, "CREATE TABLE test (id INTEGER);")
+    db.run("CREATE TABLE test (id INTEGER);")
     print(db._sync_tables_from_db())
     table_info = db.get_table_info()
     assert "CREATE TABLE test" in table_info


 def test_run_sql(db):
-    result = db.run(db.session, "CREATE TABLE test (id INTEGER);")
-    assert result[0] == ("cid", "name", "type", "notnull", "dflt_value", "pk")
+    result = db.run("CREATE TABLE test(id INTEGER);")
+    assert result[0] == ("id", "INTEGER", 0, None, 0)


 def test_run_no_throw(db):
-    assert db.run_no_throw(db.session, "this is a error sql").startswith("Error:")
+    assert db.run_no_throw("this is a error sql").startswith("Error:")


 def test_get_indexes(db):
-    db.run(db.session, "CREATE TABLE test (name TEXT);")
-    db.run(db.session, "CREATE INDEX idx_name ON test(name);")
+    db.run("CREATE TABLE test (name TEXT);")
+    db.run("CREATE INDEX idx_name ON test(name);")
     assert db.get_indexes("test") == [("idx_name", "c")]


 def test_get_indexes_empty(db):
-    db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
+    db.run("CREATE TABLE test (id INTEGER PRIMARY KEY);")
     assert db.get_indexes("test") == []


 def test_get_show_create_table(db):
-    db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
+    db.run("CREATE TABLE test (id INTEGER PRIMARY KEY);")
     assert (
         db.get_show_create_table("test") == "CREATE TABLE test (id INTEGER PRIMARY KEY)"
     )


 def test_get_fields(db):
-    db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
+    db.run("CREATE TABLE test (id INTEGER PRIMARY KEY);")
     assert db.get_fields("test") == [("id", "INTEGER", 0, None, 1)]
@@ -72,26 +72,26 @@ def test_get_collation(db):


 def test_table_simple_info(db):
-    db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
+    db.run("CREATE TABLE test (id INTEGER PRIMARY KEY);")
     assert db.table_simple_info() == ["test(id);"]


 def test_get_table_info_no_throw(db):
-    db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
+    db.run("CREATE TABLE test (id INTEGER PRIMARY KEY);")
     assert db.get_table_info_no_throw("xxxx_table").startswith("Error:")


 def test_query_ex(db):
-    db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
-    db.run(db.session, "insert into test(id) values (1)")
-    db.run(db.session, "insert into test(id) values (2)")
-    field_names, result = db.query_ex(db.session, "select * from test")
+    db.run("CREATE TABLE test (id INTEGER PRIMARY KEY);")
+    db.run("insert into test(id) values (1)")
+    db.run("insert into test(id) values (2)")
+    field_names, result = db.query_ex("select * from test")
     assert field_names == ["id"]
     assert result == [(1,), (2,)]

-    field_names, result = db.query_ex(db.session, "select * from test", fetch="one")
+    field_names, result = db.query_ex("select * from test", fetch="one")
     assert field_names == ["id"]
-    assert result == [(1,)]
+    assert result == [1]


 def test_convert_sql_write_to_select(db):
@@ -109,7 +109,7 @@ def test_get_users(db):

 def test_get_table_comments(db):
     assert db.get_table_comments() == []
-    db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
+    db.run("CREATE TABLE test (id INTEGER PRIMARY KEY);")
     assert db.get_table_comments() == [
         ("test", "CREATE TABLE test (id INTEGER PRIMARY KEY)")
     ]
@@ -4,6 +4,7 @@ from aioresponses import aioresponses
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 from httpx import AsyncClient, HTTPError
+import importlib.metadata as metadata

 from dbgpt.component import SystemApp
 from dbgpt.util.openai_utils import chat_completion_stream, chat_completion
@@ -190,12 +191,26 @@ async def test_chat_completions_with_openai_lib_async_stream(
     )

     stream_stream_resp = ""
-    async for stream_resp in await openai.ChatCompletion.acreate(
-        model=model_name,
-        messages=[{"role": "user", "content": "Hello! What is your name?"}],
-        stream=True,
-    ):
+    if metadata.version("openai") >= "1.0.0":
+        from openai import OpenAI
+
+        client = OpenAI(
+            **{"base_url": "http://test/api/v1", "api_key": client_api_key}
+        )
+        res = await client.chat.completions.create(
+            model=model_name,
+            messages=[{"role": "user", "content": "Hello! What is your name?"}],
+            stream=True,
+        )
+    else:
+        res = openai.ChatCompletion.acreate(
+            model=model_name,
+            messages=[{"role": "user", "content": "Hello! What is your name?"}],
+            stream=True,
+        )
+    async for stream_resp in res:
         stream_stream_resp = stream_resp.choices[0]["delta"].get("content", "")

     assert stream_stream_resp == expected_messages
|
22
dbgpt/storage/cache/llm_cache.py
vendored
22
dbgpt/storage/cache/llm_cache.py
vendored
dbgpt/storage/cache/llm_cache.py
@@ -75,9 +75,8 @@ class LLMCacheValueData:


 class LLMCacheKey(CacheKey[LLMCacheKeyData]):
-    def __init__(self, serializer: Serializer = None, **kwargs) -> None:
+    def __init__(self, **kwargs) -> None:
         super().__init__()
-        self._serializer = serializer
         self.config = LLMCacheKeyData(**kwargs)

     def __hash__(self) -> int:
@@ -96,30 +95,23 @@ class LLMCacheKey(CacheKey[LLMCacheKeyData]):
     def to_dict(self) -> Dict:
         return asdict(self.config)

-    def serialize(self) -> bytes:
-        return self._serializer.serialize(self)
-
     def get_value(self) -> LLMCacheKeyData:
         return self.config


 class LLMCacheValue(CacheValue[LLMCacheValueData]):
-    def __init__(self, serializer: Serializer = None, **kwargs) -> None:
+    def __init__(self, **kwargs) -> None:
         super().__init__()
-        self._serializer = serializer
         self.value = LLMCacheValueData.from_dict(**kwargs)

     def to_dict(self) -> Dict:
         return self.value.to_dict()

-    def serialize(self) -> bytes:
-        return self._serializer.serialize(self)
-
     def get_value(self) -> LLMCacheValueData:
         return self.value

     def __str__(self) -> str:
-        return f"vaue: {str(self.value)}"
+        return f"value: {str(self.value)}"


 class LLMCacheClient(CacheClient[LLMCacheKeyData, LLMCacheValueData]):
@@ -146,7 +138,11 @@ class LLMCacheClient(CacheClient[LLMCacheKeyData, LLMCacheValueData]):
         return await self.get(key, cache_config) is not None

     def new_key(self, **kwargs) -> LLMCacheKey:
-        return LLMCacheKey(serializer=self._cache_manager.serializer, **kwargs)
+        key = LLMCacheKey(**kwargs)
+        key.set_serializer(self._cache_manager.serializer)
+        return key

     def new_value(self, **kwargs) -> LLMCacheValue:
-        return LLMCacheValue(serializer=self._cache_manager.serializer, **kwargs)
+        value = LLMCacheValue(**kwargs)
+        value.set_serializer(self._cache_manager.serializer)
+        return value
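The cache key and value classes no longer take a `Serializer` in their constructors; the serializer is attached after construction via `set_serializer`, and `serialize()` is presumably inherited from the `CacheKey`/`CacheValue` base classes now that the local overrides are removed. A minimal sketch of the pattern; the keyword fields passed to `LLMCacheKey` are illustrative only:

    from dbgpt.util.serialization.json_serialization import JsonSerializer

    # Construct first, then inject the serializer, exactly as new_key/new_value do above.
    key = LLMCacheKey(prompt="hello", model_name="vicuna-13b")  # field names are illustrative
    key.set_serializer(JsonSerializer())
    data = key.serialize()  # assumed to be provided by the CacheKey base class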
dbgpt/storage/chat_history/chat_history_db.py
@@ -1,17 +1,12 @@
 from typing import Optional
-from sqlalchemy import Column, Integer, String, Index, Text
+from datetime import datetime
+from sqlalchemy import Column, Integer, String, Index, Text, DateTime
 from sqlalchemy import UniqueConstraint

-from dbgpt.storage.metadata import BaseDao
-from dbgpt.storage.metadata.meta_data import (
-    Base,
-    engine,
-    session,
-    META_DATA_DATABASE,
-)
+from dbgpt.storage.metadata import BaseDao, Model


-class ChatHistoryEntity(Base):
+class ChatHistoryEntity(Model):
     __tablename__ = "chat_history"
     id = Column(
         Integer, primary_key=True, autoincrement=True, comment="autoincrement id"
@@ -20,38 +15,60 @@ class ChatHistoryEntity(Base):
         "mysql_charset": "utf8mb4",
         "mysql_collate": "utf8mb4_unicode_ci",
     }
+    conv_uid = Column(
+        String(255),
+        unique=True,
+        nullable=False,
+        comment="Conversation record unique id",
+    )
+    chat_mode = Column(String(255), nullable=False, comment="Conversation scene mode")
+    summary = Column(String(255), nullable=False, comment="Conversation record summary")
+    user_name = Column(String(255), nullable=True, comment="interlocutor")
+    messages = Column(
+        Text(length=2**31 - 1), nullable=True, comment="Conversation details"
+    )
+    message_ids = Column(
+        Text(length=2**31 - 1), nullable=True, comment="Message ids, split by comma"
+    )
+    sys_code = Column(String(128), index=True, nullable=True, comment="System code")
+    gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time")
+    gmt_modified = Column(DateTime, default=datetime.now, comment="Record update time")
+
+    Index("idx_q_user", "user_name")
+    Index("idx_q_mode", "chat_mode")
+    Index("idx_q_conv", "summary")
+
+
+class ChatHistoryMessageEntity(Model):
+    __tablename__ = "chat_history_message"
+    id = Column(
+        Integer, primary_key=True, autoincrement=True, comment="autoincrement id"
+    )
+    __table_args__ = {
+        "mysql_charset": "utf8mb4",
+        "mysql_collate": "utf8mb4_unicode_ci",
+    }
     conv_uid = Column(
         String(255),
         unique=False,
         nullable=False,
         comment="Conversation record unique id",
     )
-    chat_mode = Column(String(255), nullable=False, comment="Conversation scene mode")
-    summary = Column(String(255), nullable=False, comment="Conversation record summary")
-    user_name = Column(String(255), nullable=True, comment="interlocutor")
-    messages = Column(
-        Text(length=2**31 - 1), nullable=True, comment="Conversation details"
-    )
-    sys_code = Column(String(128), index=True, nullable=True, comment="System code")
-    UniqueConstraint("conv_uid", name="uk_conversation")
-    Index("idx_q_user", "user_name")
-    Index("idx_q_mode", "chat_mode")
-    Index("idx_q_conv", "summary")
+    index = Column(Integer, nullable=False, comment="Message index")
+    round_index = Column(Integer, nullable=False, comment="Message round index")
+    message_detail = Column(
+        Text(length=2**31 - 1), nullable=True, comment="Message details, json format"
+    )
+    gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time")
+    gmt_modified = Column(DateTime, default=datetime.now, comment="Record update time")
+    UniqueConstraint("conv_uid", "index", name="uk_conversation_message")


 class ChatHistoryDao(BaseDao[ChatHistoryEntity]):
-    def __init__(self):
-        super().__init__(
-            database=META_DATA_DATABASE,
-            orm_base=Base,
-            db_engine=engine,
-            session=session,
-        )
-
     def list_last_20(
         self, user_name: Optional[str] = None, sys_code: Optional[str] = None
     ):
-        session = self.get_session()
+        session = self.get_raw_session()
         chat_history = session.query(ChatHistoryEntity)
         if user_name:
             chat_history = chat_history.filter(ChatHistoryEntity.user_name == user_name)
@@ -65,7 +82,7 @@ class ChatHistoryDao(BaseDao[ChatHistoryEntity]):
         return result

     def update(self, entity: ChatHistoryEntity):
-        session = self.get_session()
+        session = self.get_raw_session()
         try:
             updated = session.merge(entity)
             session.commit()
@@ -74,7 +91,7 @@ class ChatHistoryDao(BaseDao[ChatHistoryEntity]):
             session.close()

     def update_message_by_uid(self, message: str, conv_uid: str):
-        session = self.get_session()
+        session = self.get_raw_session()
         try:
             chat_history = session.query(ChatHistoryEntity)
             chat_history = chat_history.filter(ChatHistoryEntity.conv_uid == conv_uid)
@@ -85,20 +102,12 @@ class ChatHistoryDao(BaseDao[ChatHistoryEntity]):
             session.close()

     def delete(self, conv_uid: int):
-        session = self.get_session()
         if conv_uid is None:
             raise Exception("conv_uid is None")
-        chat_history = session.query(ChatHistoryEntity)
-        chat_history = chat_history.filter(ChatHistoryEntity.conv_uid == conv_uid)
-        chat_history.delete()
-        session.commit()
-        session.close()
+        with self.session() as session:
+            chat_history = session.query(ChatHistoryEntity)
+            chat_history = chat_history.filter(ChatHistoryEntity.conv_uid == conv_uid)
+            chat_history.delete()

     def get_by_uid(self, conv_uid: str) -> ChatHistoryEntity:
-        session = self.get_session()
-        chat_history = session.query(ChatHistoryEntity)
-        chat_history = chat_history.filter(ChatHistoryEntity.conv_uid == conv_uid)
-        result = chat_history.first()
-        session.close()
-        return result
+        return ChatHistoryEntity.query.filter_by(conv_uid=conv_uid).first()
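With `ChatHistoryEntity` deriving from `Model`, the entity gains the Flask-SQLAlchemy-style `query` attribute used by `get_by_uid` above, and the DAO inherits the transactional `session()` context manager from the new `BaseDao`. A minimal sketch of both access styles, assuming `db.init_db(...)` and `db.create_all()` have already run:

    dao = ChatHistoryDao()

    # Query-attribute style, as in get_by_uid:
    record = ChatHistoryEntity.query.filter_by(conv_uid="conv1").first()

    # Transactional style, as in delete; commit/rollback is handled by the context manager:
    with dao.session() as session:
        session.query(ChatHistoryEntity).filter(
            ChatHistoryEntity.conv_uid == "conv1"
        ).delete()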
dbgpt/storage/chat_history/storage_adapter.py (new file)
@@ -0,0 +1,116 @@
from typing import List, Dict, Type
import json
from sqlalchemy.orm import Session
from dbgpt.core.interface.storage import StorageItemAdapter
from dbgpt.core.interface.message import (
    StorageConversation,
    ConversationIdentifier,
    MessageIdentifier,
    MessageStorageItem,
    _messages_from_dict,
    _conversation_to_dict,
    BaseMessage,
)
from .chat_history_db import ChatHistoryEntity, ChatHistoryMessageEntity


class DBStorageConversationItemAdapter(
    StorageItemAdapter[StorageConversation, ChatHistoryEntity]
):
    def to_storage_format(self, item: StorageConversation) -> ChatHistoryEntity:
        message_ids = ",".join(item.message_ids)
        messages = None
        if not item.save_message_independent and item.messages:
            messages = _conversation_to_dict(item)
        return ChatHistoryEntity(
            conv_uid=item.conv_uid,
            chat_mode=item.chat_mode,
            summary=item.summary or item.get_latest_user_message().content,
            user_name=item.user_name,
            # We don't save messages to the chat_history table in the new design
            messages=messages,
            message_ids=message_ids,
            sys_code=item.sys_code,
        )

    def from_storage_format(self, model: ChatHistoryEntity) -> StorageConversation:
        message_ids = model.message_ids.split(",") if model.message_ids else []
        old_conversations: List[Dict] = (
            json.loads(model.messages) if model.messages else []
        )
        old_messages = []
        save_message_independent = True
        if old_conversations:
            # Load messages from the old conversations; in the old design they were saved to the chat_history table
            old_messages_dict = []
            for old_conversation in old_conversations:
                old_messages_dict.extend(
                    old_conversation["messages"]
                    if "messages" in old_conversation
                    else []
                )
            save_message_independent = False
            old_messages: List[BaseMessage] = _messages_from_dict(old_messages_dict)
        return StorageConversation(
            conv_uid=model.conv_uid,
            chat_mode=model.chat_mode,
            summary=model.summary,
            user_name=model.user_name,
            message_ids=message_ids,
            sys_code=model.sys_code,
            save_message_independent=save_message_independent,
            messages=old_messages,
        )

    def get_query_for_identifier(
        self,
        storage_format: Type[ChatHistoryEntity],
        resource_id: ConversationIdentifier,
        **kwargs,
    ):
        session: Session = kwargs.get("session")
        if session is None:
            raise Exception("session is None")
        return session.query(ChatHistoryEntity).filter(
            ChatHistoryEntity.conv_uid == resource_id.conv_uid
        )


class DBMessageStorageItemAdapter(
    StorageItemAdapter[MessageStorageItem, ChatHistoryMessageEntity]
):
    def to_storage_format(self, item: MessageStorageItem) -> ChatHistoryMessageEntity:
        round_index = item.message_detail.get("round_index", 0)
        message_detail = json.dumps(item.message_detail, ensure_ascii=False)
        return ChatHistoryMessageEntity(
            conv_uid=item.conv_uid,
            index=item.index,
            round_index=round_index,
            message_detail=message_detail,
        )

    def from_storage_format(
        self, model: ChatHistoryMessageEntity
    ) -> MessageStorageItem:
        message_detail = (
            json.loads(model.message_detail) if model.message_detail else {}
        )
        return MessageStorageItem(
            conv_uid=model.conv_uid,
            index=model.index,
            message_detail=message_detail,
        )

    def get_query_for_identifier(
        self,
        storage_format: Type[ChatHistoryMessageEntity],
        resource_id: MessageIdentifier,
        **kwargs,
    ):
        session: Session = kwargs.get("session")
        if session is None:
            raise Exception("session is None")
        return session.query(ChatHistoryMessageEntity).filter(
            ChatHistoryMessageEntity.conv_uid == resource_id.conv_uid,
            ChatHistoryMessageEntity.index == resource_id.index,
        )
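Each adapter maps between a domain object and its ORM row. A small round-trip sketch of the message adapter, with keyword construction matching `from_storage_format` above; the shape of `message_detail` is illustrative only:

    adapter = DBMessageStorageItemAdapter()
    item = MessageStorageItem(
        conv_uid="conv1",
        index=0,
        message_detail={"type": "human", "data": {"content": "hello"}, "round_index": 1},
    )
    entity = adapter.to_storage_format(item)        # ChatHistoryMessageEntity, detail as JSON text
    restored = adapter.from_storage_format(entity)  # back to a MessageStorageItem
    assert restored.message_detail == item.message_detail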
@@ -87,7 +87,7 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
             [
                 self.chat_seesion_id,
                 once_message.chat_mode,
-                once_message.get_user_conv().content,
+                once_message.get_latest_user_message().content,
                 once_message.user_name,
                 once_message.sys_code,
                 json.dumps(conversations, ensure_ascii=False),
@@ -52,14 +52,14 @@ class DbHistoryMemory(BaseChatHistoryMemory):
             if context:
                 conversations = json.loads(context)
             else:
-                chat_history.summary = once_message.get_user_conv().content
+                chat_history.summary = once_message.get_latest_user_message().content
         else:
             chat_history: ChatHistoryEntity = ChatHistoryEntity()
             chat_history.conv_uid = self.chat_seesion_id
             chat_history.chat_mode = once_message.chat_mode
             chat_history.user_name = once_message.user_name
             chat_history.sys_code = once_message.sys_code
-            chat_history.summary = once_message.get_user_conv().content
+            chat_history.summary = once_message.get_latest_user_message().content

         conversations.append(_conversation_to_dict(once_message))
         chat_history.messages = json.dumps(conversations, ensure_ascii=False)
dbgpt/storage/chat_history/tests/__init__.py (new, empty file)

dbgpt/storage/chat_history/tests/test_storage_adapter.py (new file)
@@ -0,0 +1,219 @@
import pytest
from typing import List

from dbgpt.util.pagination_utils import PaginationResult
from dbgpt.util.serialization.json_serialization import JsonSerializer
from dbgpt.core.interface.message import StorageConversation, HumanMessage, AIMessage
from dbgpt.core.interface.storage import QuerySpec
from dbgpt.storage.metadata import db
from dbgpt.storage.metadata.db_storage import SQLAlchemyStorage
from dbgpt.storage.chat_history.chat_history_db import (
    ChatHistoryEntity,
    ChatHistoryMessageEntity,
)
from dbgpt.storage.chat_history.storage_adapter import (
    DBStorageConversationItemAdapter,
    DBMessageStorageItemAdapter,
)


@pytest.fixture
def serializer():
    return JsonSerializer()


@pytest.fixture
def db_url():
    """Use in-memory SQLite database for testing"""
    return "sqlite:///:memory:"
    # return "sqlite:///test.db"


@pytest.fixture
def db_manager(db_url):
    db.init_db(db_url)
    db.create_all()
    return db


@pytest.fixture
def storage_adapter():
    return DBStorageConversationItemAdapter()


@pytest.fixture
def storage_message_adapter():
    return DBMessageStorageItemAdapter()


@pytest.fixture
def conv_storage(db_manager, serializer, storage_adapter):
    storage = SQLAlchemyStorage(
        db_manager,
        ChatHistoryEntity,
        storage_adapter,
        serializer,
    )
    return storage


@pytest.fixture
def message_storage(db_manager, serializer, storage_message_adapter):
    storage = SQLAlchemyStorage(
        db_manager,
        ChatHistoryMessageEntity,
        storage_message_adapter,
        serializer,
    )
    return storage


@pytest.fixture
def conversation(conv_storage, message_storage):
    return StorageConversation(
        "conv1",
        chat_mode="chat_normal",
        user_name="user1",
        conv_storage=conv_storage,
        message_storage=message_storage,
    )


@pytest.fixture
def four_round_conversation(conv_storage, message_storage):
    conversation = StorageConversation(
        "conv1",
        chat_mode="chat_normal",
        user_name="user1",
        conv_storage=conv_storage,
        message_storage=message_storage,
    )
    conversation.start_new_round()
    conversation.add_user_message("hello, this is first round")
    conversation.add_ai_message("hi")
    conversation.end_current_round()
    conversation.start_new_round()
    conversation.add_user_message("hello, this is second round")
    conversation.add_ai_message("hi")
    conversation.end_current_round()
    conversation.start_new_round()
    conversation.add_user_message("hello, this is third round")
    conversation.add_ai_message("hi")
    conversation.end_current_round()
    conversation.start_new_round()
    conversation.add_user_message("hello, this is fourth round")
    conversation.add_ai_message("hi")
    conversation.end_current_round()
    return conversation


@pytest.fixture
def conversation_list(request, conv_storage, message_storage):
    params = request.param if hasattr(request, "param") else {}
    conv_count = params.get("conv_count", 4)
    result = []
    for i in range(conv_count):
        conversation = StorageConversation(
            f"conv{i}",
            chat_mode="chat_normal",
            user_name="user1",
            conv_storage=conv_storage,
            message_storage=message_storage,
        )
        conversation.start_new_round()
        conversation.add_user_message("hello, this is first round")
        conversation.add_ai_message("hi")
        conversation.end_current_round()
        conversation.start_new_round()
        conversation.add_user_message("hello, this is second round")
        conversation.add_ai_message("hi")
        conversation.end_current_round()
        conversation.start_new_round()
        conversation.add_user_message("hello, this is third round")
        conversation.add_ai_message("hi")
        conversation.end_current_round()
        conversation.start_new_round()
        conversation.add_user_message("hello, this is fourth round")
        conversation.add_ai_message("hi")
        conversation.end_current_round()
        result.append(conversation)
    return result


def test_save_and_load(
    conversation: StorageConversation, conv_storage, message_storage
):
    conversation.start_new_round()
    conversation.add_user_message("hello")
    conversation.add_ai_message("hi")
    conversation.end_current_round()

    saved_conversation = StorageConversation(
        conv_uid=conversation.conv_uid,
        conv_storage=conv_storage,
        message_storage=message_storage,
    )
    assert saved_conversation.conv_uid == conversation.conv_uid
    assert len(saved_conversation.messages) == 2
    assert isinstance(saved_conversation.messages[0], HumanMessage)
    assert isinstance(saved_conversation.messages[1], AIMessage)
    assert saved_conversation.messages[0].content == "hello"
    assert saved_conversation.messages[0].round_index == 1
    assert saved_conversation.messages[1].content == "hi"
    assert saved_conversation.messages[1].round_index == 1


def test_query_message(
    conversation: StorageConversation, conv_storage, message_storage
):
    conversation.start_new_round()
    conversation.add_user_message("hello")
    conversation.add_ai_message("hi")
    conversation.end_current_round()

    saved_conversation = StorageConversation(
        conv_uid=conversation.conv_uid,
        conv_storage=conv_storage,
        message_storage=message_storage,
    )
    assert saved_conversation.conv_uid == conversation.conv_uid
    assert len(saved_conversation.messages) == 2

    query_spec = QuerySpec(conditions={"conv_uid": conversation.conv_uid})
    results = conversation.conv_storage.query(query_spec, StorageConversation)
    assert len(results) == 1


def test_complex_query(
    conversation_list: List[StorageConversation], conv_storage, message_storage
):
    query_spec = QuerySpec(conditions={"user_name": "user1"})
    results = conv_storage.query(query_spec, StorageConversation)
    assert len(results) == len(conversation_list)
    for i, result in enumerate(results):
        assert result.user_name == "user1"
        assert result.conv_uid == f"conv{i}"
        saved_conversation = StorageConversation(
            conv_uid=result.conv_uid,
            conv_storage=conv_storage,
            message_storage=message_storage,
        )
        assert len(saved_conversation.messages) == 8
        assert isinstance(saved_conversation.messages[0], HumanMessage)
        assert isinstance(saved_conversation.messages[1], AIMessage)
        assert saved_conversation.messages[0].content == "hello, this is first round"
        assert saved_conversation.messages[1].content == "hi"


def test_query_with_page(
    conversation_list: List[StorageConversation], conv_storage, message_storage
):
    query_spec = QuerySpec(conditions={"user_name": "user1"})
    page_result: PaginationResult = conv_storage.paginate_query(
        page=1, page_size=2, cls=StorageConversation, spec=query_spec
    )
    assert page_result.total_count == len(conversation_list)
    assert page_result.total_pages == 2
    assert page_result.page_size == 2
    assert len(page_result.items) == 2
    assert page_result.items[0].conv_uid == "conv0"
dbgpt/storage/metadata/__init__.py
@@ -1 +1,17 @@
+from dbgpt.storage.metadata.db_manager import (
+    db,
+    Model,
+    DatabaseManager,
+    create_model,
+    BaseModel,
+)
 from dbgpt.storage.metadata._base_dao import BaseDao
+
+__ALL__ = [
+    "db",
+    "Model",
+    "DatabaseManager",
+    "create_model",
+    "BaseModel",
+    "BaseDao",
+]
dbgpt/storage/metadata/_base_dao.py
@@ -1,25 +1,72 @@
-from typing import TypeVar, Generic, Any
-from sqlalchemy.orm import sessionmaker
+from contextlib import contextmanager
+from typing import TypeVar, Generic, Any, Optional
+from sqlalchemy.orm.session import Session

 T = TypeVar("T")

+from .db_manager import db, DatabaseManager
+

 class BaseDao(Generic[T]):
+    """The base class for all DAOs.
+
+    Examples:
+        .. code-block:: python
+
+            class UserDao(BaseDao[User]):
+                def get_user_by_name(self, name: str) -> User:
+                    with self.session() as session:
+                        return session.query(User).filter(User.name == name).first()
+
+                def get_user_by_id(self, id: int) -> User:
+                    with self.session() as session:
+                        return User.get(id)
+
+                def create_user(self, name: str) -> User:
+                    return User.create(**{"name": name})
+
+    Args:
+        db_manager (DatabaseManager, optional): The database manager. Defaults to None.
+            If None, the default database manager (db) will be used.
+    """
+
     def __init__(
         self,
-        orm_base=None,
-        database: str = None,
-        db_engine: Any = None,
-        session: Any = None,
+        db_manager: Optional[DatabaseManager] = None,
     ) -> None:
-        """BaseDAO, If the current database is a file database and create_not_exist_table=True, we will automatically create a table that does not exist"""
-        self._orm_base = orm_base
-        self._database = database
-        self._db_engine = db_engine
-        self._session = session
+        self._db_manager = db_manager or db
+
+    def get_raw_session(self) -> Session:
+        """Get a raw session object.
+
+        You should commit or rollback the session manually.
+        We suggest you use :meth:`session` instead.
+
+        Example:
+            .. code-block:: python
+
+                user = User(name="Edward Snowden")
+                session = self.get_raw_session()
+                session.add(user)
+                session.commit()
+                session.close()
+        """
+        return self._db_manager._session()

-    def get_session(self):
-        Session = sessionmaker(autocommit=False, autoflush=False, bind=self._db_engine)
-        session = Session()
-        return session
+    @contextmanager
+    def session(self) -> Session:
+        """Provide a transactional scope around a series of operations.
+
+        If an exception is raised, the session is rolled back automatically; otherwise it is committed.
+
+        Example:
+            .. code-block:: python
+
+                with self.session() as session:
+                    session.query(User).filter(User.name == 'Edward Snowden').first()
+
+        Returns:
+            Session: A session object.
+
+        Raises:
+            Exception: Any exception will be raised.
+        """
+        with self._db_manager.session() as session:
+            yield session
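A short sketch of the new constructor in use: a DAO bound to a dedicated `DatabaseManager` instead of the global `db`. The `User` model is assumed to be declared as in the docstring above; names here are illustrative:

    from dbgpt.storage.metadata import BaseDao, DatabaseManager

    manager = DatabaseManager()
    manager.init_db("sqlite:///:memory:")

    class UserDao(BaseDao["User"]):
        def get_by_name(self, name: str):
            with self.session() as session:  # commit/rollback handled by the context manager
                return session.query(User).filter(User.name == name).first()

    dao = UserDao(db_manager=manager)  # omit the argument to fall back to the global `db`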
dbgpt/storage/metadata/db_manager.py (new file)
@@ -0,0 +1,432 @@
from __future__ import annotations

import abc
from contextlib import contextmanager
from typing import TypeVar, Generic, Union, Dict, Optional, Type, Iterator, List
import logging
from sqlalchemy import create_engine, URL, Engine
from sqlalchemy import orm, inspect, MetaData
from sqlalchemy.orm import (
    scoped_session,
    sessionmaker,
    Session,
    declarative_base,
    DeclarativeMeta,
)
from sqlalchemy.orm.session import _PKIdentityArgument
from sqlalchemy.orm.exc import UnmappedClassError

from sqlalchemy.pool import QueuePool
from dbgpt.util.string_utils import _to_str
from dbgpt.util.pagination_utils import PaginationResult

logger = logging.getLogger(__name__)
T = TypeVar("T", bound="BaseModel")


class _QueryObject:
    """The query object."""

    def __init__(self, db_manager: "DatabaseManager"):
        self._db_manager = db_manager

    def __get__(self, obj, type):
        try:
            mapper = orm.class_mapper(type)
            if mapper:
                return type.query_class(mapper, session=self._db_manager._session())
        except UnmappedClassError:
            return None


class BaseQuery(orm.Query):
    def paginate_query(
        self, page: Optional[int] = 1, per_page: Optional[int] = 20
    ) -> PaginationResult:
        """Paginate the query.

        Example:
            .. code-block:: python

                from dbgpt.storage.metadata import db, Model

                class User(Model):
                    __tablename__ = "user"
                    id = Column(Integer, primary_key=True)
                    name = Column(String(50))
                    fullname = Column(String(50))

                with db.session() as session:
                    pagination = session.query(User).paginate_query(page=1, per_page=10)
                    print(pagination)

                # Or you can use the query object
                with db.session() as session:
                    pagination = User.query.paginate_query(page=1, per_page=10)
                    print(pagination)

        Args:
            page (Optional[int], optional): The page number. Defaults to 1.
            per_page (Optional[int], optional): The number of items per page. Defaults to 20.
        Returns:
            PaginationResult: The pagination result.
        """
        if page < 1:
            raise ValueError("Page must be greater than 0")
        if per_page <= 0:
            raise ValueError("Per page must be greater than 0")
        items = self.limit(per_page).offset((page - 1) * per_page).all()
        total = self.order_by(None).count()
        total_pages = (total - 1) // per_page + 1
        return PaginationResult(
            items=items,
            total_count=total,
            total_pages=total_pages,
            page=page,
            page_size=per_page,
        )


class _Model:
    """Base class for SQLAlchemy declarative base model.

    With this class, we can use the query object to query the database.

    Examples:
        .. code-block:: python

            from dbgpt.storage.metadata import db, Model

            class User(Model):
                __tablename__ = "user"
                id = Column(Integer, primary_key=True)
                name = Column(String(50))
                fullname = Column(String(50))

            with db.session() as session:
                # User is an instance of _Model, and we can use the query object to query the database.
                User.query.filter(User.name == "test").all()
    """

    query_class = None
    query: Optional[BaseQuery] = None

    def __repr__(self):
        identity = inspect(self).identity
        if identity is None:
            pk = "(transient {0})".format(id(self))
        else:
            pk = ", ".join(_to_str(value) for value in identity)
        return "<{0} {1}>".format(type(self).__name__, pk)


class DatabaseManager:
    """The database manager.

    Examples:
        .. code-block:: python

            from urllib.parse import quote_plus as urlquote, quote
            from dbgpt.storage.metadata import DatabaseManager, create_model

            db = DatabaseManager()
            # Use sqlite with memory storage.
            url = f"sqlite:///:memory:"
            engine_args = {"pool_size": 10, "max_overflow": 20, "pool_timeout": 30, "pool_recycle": 3600, "pool_pre_ping": True}
            db.init_db(url, engine_args=engine_args)

            Model = create_model(db)

            class User(Model):
                __tablename__ = "user"
                id = Column(Integer, primary_key=True)
                name = Column(String(50))
                fullname = Column(String(50))

            with db.session() as session:
                session.add(User(name="test", fullname="test"))
                # db will commit the session automatically by default.
                # session.commit()
                print(User.query.filter(User.name == "test").all())


            # Use the CRUD mixin APIs to create, update, delete and query the database.
            with db.session() as session:
                User.create(**{"name": "test1", "fullname": "test1"})
                User.create(**{"name": "test2", "fullname": "test1"})
                users = User.all()
                print(users)
                user = users[0]
                user.update(**{"name": "test1_1111"})
                user2 = users[1]
                # Update user2 by save
                user2.name = "test2_1111"
                user2.save()
                # Delete user2
                user2.delete()
    """

    Query = BaseQuery

    def __init__(self):
        self._db_url = None
        self._base: DeclarativeMeta = self._make_declarative_base(_Model)
        self._engine: Optional[Engine] = None
        self._session: Optional[scoped_session] = None

    @property
    def Model(self) -> _Model:
        """Get the declarative base."""
        return self._base

    @property
    def metadata(self) -> MetaData:
        """Get the metadata."""
        return self.Model.metadata

    @property
    def engine(self):
        """Get the engine."""
        return self._engine

    @contextmanager
    def session(self) -> Session:
        """Get the session with context manager.

        If any exception is raised, the session is rolled back automatically; otherwise it is committed automatically.

        Example:
            >>> with db.session() as session:
            >>>     session.query(...)

        Returns:
            Session: The session.

        Raises:
            RuntimeError: The database manager is not initialized.
            Exception: Any exception.
        """
        if not self._session:
            raise RuntimeError("The database manager is not initialized.")
        session = self._session()
        try:
            yield session
            session.commit()
        except:
            session.rollback()
            raise
        finally:
            session.close()

    def _make_declarative_base(
        self, model: Union[Type[DeclarativeMeta], Type[_Model]]
    ) -> DeclarativeMeta:
        """Make the declarative base.

        Args:
            base (DeclarativeMeta): The base class.

        Returns:
            DeclarativeMeta: The declarative base.
        """
        if not isinstance(model, DeclarativeMeta):
            model = declarative_base(cls=model, name="Model")
        if not getattr(model, "query_class", None):
            model.query_class = self.Query
        model.query = _QueryObject(self)
        return model

    def init_db(
        self,
        db_url: Union[str, URL],
        engine_args: Optional[Dict] = None,
        base: Optional[DeclarativeMeta] = None,
        query_class=BaseQuery,
    ):
        """Initialize the database manager.

        Args:
            db_url (Union[str, URL]): The database url.
            engine_args (Optional[Dict], optional): The engine arguments. Defaults to None.
            base (Optional[DeclarativeMeta]): The base class. Defaults to None.
            query_class (BaseQuery, optional): The query class. Defaults to BaseQuery.
        """
        self._db_url = db_url
        if query_class is not None:
            self.Query = query_class
        if base is not None:
            self._base = base
            if not hasattr(base, "query"):
                base.query = _QueryObject(self)
            if not getattr(base, "query_class", None):
                base.query_class = self.Query
        self._engine = create_engine(db_url, **(engine_args or {}))
        session_factory = sessionmaker(bind=self._engine)
        self._session = scoped_session(session_factory)
        self._base.metadata.bind = self._engine

    def init_default_db(
        self,
        sqlite_path: str,
        engine_args: Optional[Dict] = None,
        base: Optional[DeclarativeMeta] = None,
    ):
        """Initialize the database manager with default config.

        Examples:
            >>> db.init_default_db(sqlite_path)
            >>> with db.session() as session:
            >>>     session.query(...)

        Args:
            sqlite_path (str): The sqlite path.
            engine_args (Optional[Dict], optional): The engine arguments.
                Defaults to None; if None, we will use a connection pool.
            base (Optional[DeclarativeMeta]): The base class. Defaults to None.
        """
        if not engine_args:
            engine_args = {}
            # Pool class
            engine_args["poolclass"] = QueuePool
            # The number of connections to keep open inside the connection pool.
            engine_args["pool_size"] = 10
            # The maximum overflow size of the pool when the number of connections used in the pool exceeds pool_size.
            engine_args["max_overflow"] = 20
            # The number of seconds to wait before giving up on getting a connection from the pool.
            engine_args["pool_timeout"] = 30
            # Recycle the connection if it has been idle for this many seconds.
            engine_args["pool_recycle"] = 3600
            # Enable the connection pool "pre-ping" feature that tests connections for liveness upon each checkout.
            engine_args["pool_pre_ping"] = True

        self.init_db(f"sqlite:///{sqlite_path}", engine_args, base)

    def create_all(self):
        self.Model.metadata.create_all(self._engine)


db = DatabaseManager()
"""The global database manager.

Examples:
    >>> from dbgpt.storage.metadata import db
    >>> sqlite_path = "/tmp/dbgpt.db"
    >>> db.init_default_db(sqlite_path)
    >>> with db.session() as session:
    >>>     session.query(...)

    >>> from dbgpt.storage.metadata import db, Model
    >>> from urllib.parse import quote_plus as urlquote, quote
    >>> db_name = "dbgpt"
    >>> db_host = "localhost"
    >>> db_port = 3306
    >>> user = "root"
    >>> password = "123456"
    >>> url = f"mysql+pymysql://{quote(user)}:{urlquote(password)}@{db_host}:{str(db_port)}/{db_name}"
    >>> engine_args = {"pool_size": 10, "max_overflow": 20, "pool_timeout": 30, "pool_recycle": 3600, "pool_pre_ping": True}
    >>> db.init_db(url, engine_args=engine_args)
    >>> class User(Model):
    >>>     __tablename__ = "user"
    >>>     id = Column(Integer, primary_key=True)
    >>>     name = Column(String(50))
    >>>     fullname = Column(String(50))
    >>> with db.session() as session:
    >>>     session.add(User(name="test", fullname="test"))
    >>>     session.commit()
"""


class BaseCRUDMixin(Generic[T]):
    """The base CRUD mixin."""

    __abstract__ = True

    @classmethod
    def create(cls: Type[T], **kwargs) -> T:
        instance = cls(**kwargs)
        return instance.save()

    @classmethod
    def all(cls: Type[T]) -> List[T]:
        return cls.query.all()

    @classmethod
    def get(cls: Type[T], ident: _PKIdentityArgument) -> Optional[T]:
        """Get a record by its primary key identifier."""

    def update(self: T, commit: Optional[bool] = True, **kwargs) -> T:
        """Update specific fields of a record."""
        for attr, value in kwargs.items():
            setattr(self, attr, value)
        return commit and self.save() or self

    @abc.abstractmethod
    def save(self: T, commit: Optional[bool] = True) -> T:
        """Save the record."""

    @abc.abstractmethod
    def delete(self: T, commit: Optional[bool] = True) -> None:
        """Remove the record from the database."""


class BaseModel(BaseCRUDMixin[T], _Model, Generic[T]):
    """The base model class that includes CRUD convenience methods."""

    __abstract__ = True


def create_model(db_manager: DatabaseManager) -> Type[BaseModel[T]]:
    class CRUDMixin(BaseCRUDMixin[T], Generic[T]):
        """Mixin that adds convenience methods for CRUD (create, read, update, delete)"""

        @classmethod
        def get(cls: Type[T], ident: _PKIdentityArgument) -> Optional[T]:
            """Get a record by its primary key identifier."""
            return db_manager._session().get(cls, ident)

        def save(self: T, commit: Optional[bool] = True) -> T:
            """Save the record."""
            session = db_manager._session()
            session.add(self)
            if commit:
                session.commit()
            return self

        def delete(self: T, commit: Optional[bool] = True) -> None:
            """Remove the record from the database."""
            session = db_manager._session()
            session.delete(self)
            return commit and session.commit()

    class _NewModel(CRUDMixin[T], db_manager.Model, Generic[T]):
        """Base model class that includes CRUD convenience methods."""

        __abstract__ = True

    return _NewModel


Model = create_model(db)


def initialize_db(
    db_url: Union[str, URL],
    db_name: str,
    engine_args: Optional[Dict] = None,
    base: Optional[DeclarativeMeta] = None,
    try_to_create_db: Optional[bool] = False,
) -> DatabaseManager:
    """Initialize the database manager.

    Args:
        db_url (Union[str, URL]): The database url.
        db_name (str): The database name.
        engine_args (Optional[Dict], optional): The engine arguments. Defaults to None.
        base (Optional[DeclarativeMeta]): The base class. Defaults to None.
        try_to_create_db (Optional[bool], optional): Whether to try to create the database. Defaults to False.
    Returns:
        DatabaseManager: The database manager.
    """
    db.init_db(db_url, engine_args, base)
    if try_to_create_db:
        try:
            db.create_all()
        except Exception as e:
            logger.error(f"Failed to create database {db_name}: {e}")
    return db
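`initialize_db` is the one-call setup for the global manager. A minimal sketch, with the URL and engine arguments adapted from the docstring examples above:

    from dbgpt.storage.metadata.db_manager import initialize_db

    db = initialize_db(
        "sqlite:////tmp/dbgpt.db",
        db_name="dbgpt",
        engine_args={"pool_pre_ping": True},
        try_to_create_db=True,  # runs create_all(), logging (not raising) on failure
    )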
dbgpt/storage/metadata/db_storage.py (new file)
@@ -0,0 +1,128 @@
from contextlib import contextmanager

from typing import Type, List, Optional, Union, Dict
from dbgpt.core import Serializer
from dbgpt.core.interface.storage import (
    StorageInterface,
    QuerySpec,
    ResourceIdentifier,
    StorageItemAdapter,
    T,
)
from sqlalchemy import URL
from sqlalchemy.orm import Session, DeclarativeMeta

from .db_manager import BaseModel, DatabaseManager, BaseQuery


def _copy_public_properties(src: BaseModel, dest: BaseModel):
    """Simple copy public properties from src to dest"""
    for column in src.__table__.columns:
        if column.name != "id":
            setattr(dest, column.name, getattr(src, column.name))


class SQLAlchemyStorage(StorageInterface[T, BaseModel]):
    def __init__(
        self,
        db_url_or_db: Union[str, URL, DatabaseManager],
        model_class: Type[BaseModel],
        adapter: StorageItemAdapter[T, BaseModel],
        serializer: Optional[Serializer] = None,
        engine_args: Optional[Dict] = None,
        base: Optional[DeclarativeMeta] = None,
        query_class=BaseQuery,
    ):
        super().__init__(serializer=serializer, adapter=adapter)
        if isinstance(db_url_or_db, str) or isinstance(db_url_or_db, URL):
            db_manager = DatabaseManager()
            db_manager.init_db(db_url_or_db, engine_args, base, query_class)
            self.db_manager = db_manager
        elif isinstance(db_url_or_db, DatabaseManager):
            self.db_manager = db_url_or_db
        else:
            raise ValueError(
                f"db_url_or_db should be either url or a DatabaseManager, got {type(db_url_or_db)}"
            )
        self._model_class = model_class

    @contextmanager
    def session(self) -> Session:
        with self.db_manager.session() as session:
            yield session

    def save(self, data: T) -> None:
        with self.session() as session:
            model_instance = self.adapter.to_storage_format(data)
            session.add(model_instance)

    def update(self, data: T) -> None:
        with self.session() as session:
            model_instance = self.adapter.to_storage_format(data)
            session.merge(model_instance)

    def save_or_update(self, data: T) -> None:
        with self.session() as session:
            query = self.adapter.get_query_for_identifier(
                self._model_class, data.identifier, session=session
            )
            model_instance = query.with_session(session).first()
            if model_instance:
                new_instance = self.adapter.to_storage_format(data)
                _copy_public_properties(new_instance, model_instance)
                session.merge(model_instance)
                return
        self.save(data)

    def load(self, resource_id: ResourceIdentifier, cls: Type[T]) -> Optional[T]:
        with self.session() as session:
            query = self.adapter.get_query_for_identifier(
                self._model_class, resource_id, session=session
            )
            model_instance = query.with_session(session).first()
            if model_instance:
                return self.adapter.from_storage_format(model_instance)
            return None

    def delete(self, resource_id: ResourceIdentifier) -> None:
        with self.session() as session:
            query = self.adapter.get_query_for_identifier(
                self._model_class, resource_id, session=session
            )
            model_instance = query.with_session(session).first()
            if model_instance:
                session.delete(model_instance)

    def query(self, spec: QuerySpec, cls: Type[T]) -> List[T]:
        """Query data from the storage.

        Args:
            spec (QuerySpec): The query specification
            cls (Type[T]): The type of the data
        """
        with self.session() as session:
            query = session.query(self._model_class)
            for key, value in spec.conditions.items():
                query = query.filter(getattr(self._model_class, key) == value)
            if spec.limit is not None:
                query = query.limit(spec.limit)
            if spec.offset is not None:
                query = query.offset(spec.offset)
            model_instances = query.all()
            return [
                self.adapter.from_storage_format(instance)
                for instance in model_instances
            ]

    def count(self, spec: QuerySpec, cls: Type[T]) -> int:
        """Count the number of data in the storage.

        Args:
            spec (QuerySpec): The query specification
            cls (Type[T]): The type of the data
        """
        with self.session() as session:
            query = session.query(self._model_class)
            for key, value in spec.conditions.items():
                query = query.filter(getattr(self._model_class, key) == value)
            return query.count()
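A compact sketch of the storage facade in use, reusing the chat-history adapter added earlier in this commit. Passing a URL makes the storage own its `DatabaseManager`, in which case the tables still have to be created before first use:

    from dbgpt.core.interface.message import StorageConversation
    from dbgpt.core.interface.storage import QuerySpec
    from dbgpt.storage.chat_history.chat_history_db import ChatHistoryEntity
    from dbgpt.storage.chat_history.storage_adapter import DBStorageConversationItemAdapter
    from dbgpt.util.serialization.json_serialization import JsonSerializer

    storage = SQLAlchemyStorage(
        "sqlite:///:memory:",  # URL form: the storage builds its own DatabaseManager
        ChatHistoryEntity,
        DBStorageConversationItemAdapter(),
        JsonSerializer(),
    )
    storage.db_manager.create_all()  # tables are not created automatically here

    # Conditions filter columns by equality; limit/offset on QuerySpec are optional.
    results = storage.query(
        QuerySpec(conditions={"user_name": "user1"}), StorageConversation
    )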
dbgpt/storage/metadata/meta_data.py (deleted, 94 lines)
@@ -1,94 +0,0 @@
import os
import sqlite3
import logging

from sqlalchemy import create_engine, DDL
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.declarative import declarative_base

from alembic import command
from alembic.config import Config as AlembicConfig
from urllib.parse import quote
from dbgpt._private.config import Config
from dbgpt.configs.model_config import PILOT_PATH
from urllib.parse import quote_plus as urlquote


logger = logging.getLogger(__name__)
# DB-GPT metadata database config, now support mysql and sqlite
CFG = Config()
default_db_path = os.path.join(PILOT_PATH, "meta_data")

os.makedirs(default_db_path, exist_ok=True)

# Meta Info
META_DATA_DATABASE = CFG.LOCAL_DB_NAME
db_name = META_DATA_DATABASE
db_path = default_db_path + f"/{db_name}.db"
connection = sqlite3.connect(db_path)


if CFG.LOCAL_DB_TYPE == "mysql":
    engine_temp = create_engine(
        f"mysql+pymysql://{quote(CFG.LOCAL_DB_USER)}:{urlquote(CFG.LOCAL_DB_PASSWORD)}@{CFG.LOCAL_DB_HOST}:{str(CFG.LOCAL_DB_PORT)}"
    )
    # check and auto create mysqldatabase
    try:
        # try to connect
        with engine_temp.connect() as conn:
            # TODO We should consider that the production environment does not have permission to execute the DDL
            conn.execute(DDL(f"CREATE DATABASE IF NOT EXISTS {db_name}"))
        print(f"Already connect '{db_name}'")

    except OperationalError as e:
        # if connect failed, create dbgpt database
        logger.error(f"{db_name} not connect success!")

    engine = create_engine(
        f"mysql+pymysql://{quote(CFG.LOCAL_DB_USER)}:{urlquote(CFG.LOCAL_DB_PASSWORD)}@{CFG.LOCAL_DB_HOST}:{str(CFG.LOCAL_DB_PORT)}/{db_name}"
    )
else:
    engine = create_engine(f"sqlite:///{db_path}")


Session = sessionmaker(autocommit=False, autoflush=False, bind=engine)
session = Session()

Base = declarative_base()

# Base.metadata.create_all()

alembic_ini_path = default_db_path + "/alembic.ini"
alembic_cfg = AlembicConfig(alembic_ini_path)

alembic_cfg.set_main_option("sqlalchemy.url", str(engine.url))

os.makedirs(default_db_path + "/alembic", exist_ok=True)
os.makedirs(default_db_path + "/alembic/versions", exist_ok=True)

alembic_cfg.set_main_option("script_location", default_db_path + "/alembic")

alembic_cfg.attributes["target_metadata"] = Base.metadata
alembic_cfg.attributes["session"] = session


def ddl_init_and_upgrade(disable_alembic_upgrade: bool):
    """Initialize and upgrade database metadata

    Args:
        disable_alembic_upgrade (bool): Whether to enable alembic to initialize and upgrade database metadata
    """
    if disable_alembic_upgrade:
        logger.info(
            "disable_alembic_upgrade is true, not to initialize and upgrade database metadata with alembic"
        )
        return

    with engine.connect() as connection:
        alembic_cfg.attributes["connection"] = connection
        heads = command.heads(alembic_cfg)
        print("heads:" + str(heads))

        command.revision(alembic_cfg, "dbgpt ddl upate", True)
        command.upgrade(alembic_cfg, "head")
@@ -0,0 +1,129 @@
from __future__ import annotations

import pytest
from typing import Type
from dbgpt.storage.metadata.db_manager import (
    DatabaseManager,
    PaginationResult,
    create_model,
    BaseModel,
)
from sqlalchemy import Column, Integer, String


@pytest.fixture
def db():
    db = DatabaseManager()
    db.init_db("sqlite:///:memory:")
    return db


@pytest.fixture
def Model(db):
    return create_model(db)


def test_database_initialization(db: DatabaseManager, Model: Type[BaseModel]):
    assert db.engine is not None
    assert db.session is not None

    with db.session() as session:
        assert session is not None


def test_model_creation(db: DatabaseManager, Model: Type[BaseModel]):
    assert db.metadata.tables == {}

    class User(Model):
        __tablename__ = "user"
        id = Column(Integer, primary_key=True)
        name = Column(String(50))

    db.create_all()
    assert list(db.metadata.tables.keys())[0] == "user"


def test_crud_operations(db: DatabaseManager, Model: Type[BaseModel]):
    class User(Model):
        __tablename__ = "user"
        id = Column(Integer, primary_key=True)
        name = Column(String(50))

    db.create_all()

    # Create
    with db.session() as session:
        user = User.create(name="John Doe")
        session.add(user)
        session.commit()

    # Read
    with db.session() as session:
        user = session.query(User).filter_by(name="John Doe").first()
        assert user is not None

    # Update
    with db.session() as session:
        user = session.query(User).filter_by(name="John Doe").first()
        user.update(name="Jane Doe")

    # Delete
    with db.session() as session:
        user = session.query(User).filter_by(name="Jane Doe").first()
        user.delete()


def test_crud_mixins(db: DatabaseManager, Model: Type[BaseModel]):
    class User(Model):
        __tablename__ = "user"
        id = Column(Integer, primary_key=True)
        name = Column(String(50))

    db.create_all()

    # Create
    user = User.create(name="John Doe")
    assert User.get(user.id) is not None
    users = User.all()
    assert len(users) == 1

    # Update
    user.update(name="Bob Doe")
    assert User.get(user.id).name == "Bob Doe"

    # Delete
    user = User.get(user.id)
    user.delete()
    assert User.get(user.id) is None


def test_pagination_query(db: DatabaseManager, Model: Type[BaseModel]):
    class User(Model):
        __tablename__ = "user"
        id = Column(Integer, primary_key=True)
        name = Column(String(50))

    db.create_all()

    # Add data
    with db.session() as session:
        for i in range(30):
            user = User(name=f"User {i}")
            session.add(user)
        session.commit()

    users_page_1 = User.query.paginate_query(page=1, per_page=10)
    assert len(users_page_1.items) == 10
    assert users_page_1.total_pages == 3


def test_invalid_pagination(db: DatabaseManager, Model: Type[BaseModel]):
    class User(Model):
        __tablename__ = "user"
        id = Column(Integer, primary_key=True)
        name = Column(String(50))

    db.create_all()

    with pytest.raises(ValueError):
        User.query.paginate_query(page=0, per_page=10)
    with pytest.raises(ValueError):
        User.query.paginate_query(page=1, per_page=-1)
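Taken together, the tests above exercise a compact ORM facade. A condensed sketch of the same flow outside pytest (only names that appear in the tests are used; treat this as illustrative, not as the full API):

```python
# Condensed sketch of the db_manager API exercised by the tests above.
from sqlalchemy import Column, Integer, String
from dbgpt.storage.metadata.db_manager import DatabaseManager, create_model

db = DatabaseManager()
db.init_db("sqlite:///:memory:")   # bind an engine and session factory
Model = create_model(db)           # declarative base wired to this manager


class User(Model):
    __tablename__ = "user"
    id = Column(Integer, primary_key=True)
    name = Column(String(50))


db.create_all()                    # create tables for all registered models
user = User.create(name="Alice")   # CRUD mixin helper
page = User.query.paginate_query(page=1, per_page=10)
print(len(page.items), page.total_pages)
```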
173 dbgpt/storage/metadata/tests/test_sqlalchemy_storage.py Normal file
@@ -0,0 +1,173 @@
from typing import Dict, Type

from sqlalchemy.orm import declarative_base, Session
from sqlalchemy import Column, Integer, String

import pytest

from dbgpt.core.interface.storage import (
    StorageItem,
    ResourceIdentifier,
    StorageItemAdapter,
    QuerySpec,
)
from dbgpt.storage.metadata.db_storage import SQLAlchemyStorage

from dbgpt.core.interface.tests.test_storage import MockResourceIdentifier
from dbgpt.util.serialization.json_serialization import JsonSerializer


Base = declarative_base()


class MockModel(Base):
    """The SQLAlchemy model for the mock data."""

    __tablename__ = "mock_data"
    id = Column(Integer, primary_key=True)
    data = Column(String)


class MockStorageItem(StorageItem):
    """The mock storage item."""

    def merge(self, other: "StorageItem") -> None:
        if not isinstance(other, MockStorageItem):
            raise ValueError("other must be a MockStorageItem")
        self.data = other.data

    def __init__(self, identifier: ResourceIdentifier, data: str):
        self._identifier = identifier
        self.data = data

    @property
    def identifier(self) -> ResourceIdentifier:
        return self._identifier

    def to_dict(self) -> Dict:
        return {"identifier": self._identifier, "data": self.data}

    def serialize(self) -> bytes:
        return str(self.data).encode()


class MockStorageItemAdapter(StorageItemAdapter[MockStorageItem, MockModel]):
    """The adapter for the mock storage item."""

    def to_storage_format(self, item: MockStorageItem) -> MockModel:
        return MockModel(id=int(item.identifier.str_identifier), data=item.data)

    def from_storage_format(self, model: MockModel) -> MockStorageItem:
        return MockStorageItem(MockResourceIdentifier(str(model.id)), model.data)

    def get_query_for_identifier(
        self,
        storage_format: Type[MockModel],
        resource_id: ResourceIdentifier,
        **kwargs,
    ):
        session: Session = kwargs.get("session")
        if session is None:
            raise ValueError("session is required for this adapter")
        return session.query(storage_format).filter(
            storage_format.id == int(resource_id.str_identifier)
        )


@pytest.fixture
def serializer():
    return JsonSerializer()


@pytest.fixture
def db_url():
    """Use an in-memory SQLite database for testing"""
    return "sqlite:///:memory:"


@pytest.fixture
def sqlalchemy_storage(db_url, serializer):
    adapter = MockStorageItemAdapter()
    storage = SQLAlchemyStorage(db_url, MockModel, adapter, serializer, base=Base)
    Base.metadata.create_all(storage.db_manager.engine)
    return storage


def test_save_and_load(sqlalchemy_storage):
    item = MockStorageItem(MockResourceIdentifier("1"), "test_data")

    sqlalchemy_storage.save(item)

    loaded_item = sqlalchemy_storage.load(MockResourceIdentifier("1"), MockStorageItem)
    assert loaded_item.data == "test_data"


def test_delete(sqlalchemy_storage):
    resource_id = MockResourceIdentifier("1")

    sqlalchemy_storage.delete(resource_id)
    # Make sure the item is deleted
    assert sqlalchemy_storage.load(resource_id, MockStorageItem) is None


def test_query_with_various_conditions(sqlalchemy_storage):
    # Add multiple items for testing
    for i in range(5):
        item = MockStorageItem(MockResourceIdentifier(str(i)), f"test_data_{i}")
        sqlalchemy_storage.save(item)

    # Test query with a single condition
    query_spec = QuerySpec(conditions={"data": "test_data_2"})
    results = sqlalchemy_storage.query(query_spec, MockStorageItem)
    assert len(results) == 1
    assert results[0].data == "test_data_2"

    # Test a nonexistent condition
    query_spec = QuerySpec(conditions={"data": "nonexistent"})
    results = sqlalchemy_storage.query(query_spec, MockStorageItem)
    assert len(results) == 0

    # Test query with multiple conditions
    query_spec = QuerySpec(conditions={"data": "test_data_2", "id": "2"})
    results = sqlalchemy_storage.query(query_spec, MockStorageItem)
    assert len(results) == 1


def test_query_nonexistent_item(sqlalchemy_storage):
    query_spec = QuerySpec(conditions={"data": "nonexistent"})
    results = sqlalchemy_storage.query(query_spec, MockStorageItem)
    assert len(results) == 0


def test_count_items(sqlalchemy_storage):
    for i in range(5):
        item = MockStorageItem(MockResourceIdentifier(str(i)), f"test_data_{i}")
        sqlalchemy_storage.save(item)

    # Test count without conditions
    query_spec = QuerySpec(conditions={})
    total_count = sqlalchemy_storage.count(query_spec, MockStorageItem)
    assert total_count == 5

    # Test count with conditions
    query_spec = QuerySpec(conditions={"data": "test_data_2"})
    total_count = sqlalchemy_storage.count(query_spec, MockStorageItem)
    assert total_count == 1


def test_paginate_query(sqlalchemy_storage):
    for i in range(10):
        item = MockStorageItem(MockResourceIdentifier(str(i)), f"test_data_{i}")
        sqlalchemy_storage.save(item)

    page_size = 3
    page_number = 2

    query_spec = QuerySpec(conditions={})
    page_result = sqlalchemy_storage.paginate_query(
        page_number, page_size, MockStorageItem, query_spec
    )

    assert len(page_result.items) == page_size
    assert page_result.page == page_number
    assert page_result.total_pages == 4
    assert page_result.total_count == 10
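For reference, the fixtures above condense to a few lines when used directly. This sketch reuses only the mock classes and calls defined in this file (the in-memory URL is just for illustration):

```python
# Condensed wiring of SQLAlchemyStorage, mirroring the fixtures above.
serializer = JsonSerializer()
storage = SQLAlchemyStorage(
    "sqlite:///:memory:", MockModel, MockStorageItemAdapter(), serializer, base=Base
)
Base.metadata.create_all(storage.db_manager.engine)

storage.save(MockStorageItem(MockResourceIdentifier("42"), "hello"))
item = storage.load(MockResourceIdentifier("42"), MockStorageItem)
assert item is not None and item.data == "hello"
```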
219 dbgpt/util/_db_migration_utils.py Normal file
@@ -0,0 +1,219 @@
from typing import Optional
import os
import logging
from sqlalchemy import Engine, text
from sqlalchemy.orm import Session, DeclarativeMeta
from alembic import command
from alembic.util.exc import CommandError
from alembic.config import Config as AlembicConfig


logger = logging.getLogger(__name__)


def create_alembic_config(
    alembic_root_path: str,
    engine: Engine,
    base: DeclarativeMeta,
    session: Session,
    alembic_ini_path: Optional[str] = None,
    script_location: Optional[str] = None,
) -> AlembicConfig:
    """Create an alembic config.

    Args:
        alembic_root_path: alembic root path
        engine: sqlalchemy engine
        base: sqlalchemy base
        session: sqlalchemy session
        alembic_ini_path (Optional[str]): alembic ini path
        script_location (Optional[str]): alembic script location

    Returns:
        alembic config
    """
    alembic_ini_path = alembic_ini_path or os.path.join(
        alembic_root_path, "alembic.ini"
    )
    alembic_cfg = AlembicConfig(alembic_ini_path)
    alembic_cfg.set_main_option("sqlalchemy.url", str(engine.url))
    script_location = script_location or os.path.join(alembic_root_path, "alembic")
    versions_dir = os.path.join(script_location, "versions")

    os.makedirs(script_location, exist_ok=True)
    os.makedirs(versions_dir, exist_ok=True)

    alembic_cfg.set_main_option("script_location", script_location)

    alembic_cfg.attributes["target_metadata"] = base.metadata
    alembic_cfg.attributes["session"] = session
    return alembic_cfg


def create_migration_script(
    alembic_cfg: AlembicConfig, engine: Engine, message: str = "New migration"
) -> None:
    """Create a migration script.

    Args:
        alembic_cfg: alembic config
        engine: sqlalchemy engine
        message: migration message
    """
    with engine.connect() as connection:
        alembic_cfg.attributes["connection"] = connection
        command.revision(alembic_cfg, message, autogenerate=True)


def upgrade_database(
    alembic_cfg: AlembicConfig, engine: Engine, target_version: str = "head"
) -> None:
    """Upgrade the database to the target version.

    Args:
        alembic_cfg: alembic config
        engine: sqlalchemy engine
        target_version: target version, default is "head" (the latest version)
    """
    with engine.connect() as connection:
        alembic_cfg.attributes["connection"] = connection
        # Will create tables if they do not exist
        command.upgrade(alembic_cfg, target_version)


def downgrade_database(
    alembic_cfg: AlembicConfig, engine: Engine, revision: str = "-1"
):
    """Downgrade the database by one revision.

    Args:
        alembic_cfg: Alembic configuration object.
        engine: SQLAlchemy engine instance.
        revision: Revision identifier, default is "-1" which means one revision back.
    """
    with engine.connect() as connection:
        alembic_cfg.attributes["connection"] = connection
        command.downgrade(alembic_cfg, revision)
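Taken together, these helpers form a small programmatic migration pipeline. A minimal sketch of composing them (the paths, engine URL, and model base are illustrative assumptions, and an initialized `alembic.ini` plus script directory, such as the ones DB-GPT ships under `pilot/meta_data`, are assumed to exist):

```python
# Minimal sketch wiring the helpers above together; paths and Base are
# assumptions for illustration.
from sqlalchemy import create_engine
from sqlalchemy.orm import declarative_base, sessionmaker

from dbgpt.util._db_migration_utils import (
    create_alembic_config,
    create_migration_script,
    upgrade_database,
)

engine = create_engine("sqlite:///pilot/meta_data/dbgpt.db")
Base = declarative_base()  # your models' declarative base
SessionLocal = sessionmaker(bind=engine)

alembic_cfg = create_alembic_config(
    alembic_root_path="pilot/meta_data",
    engine=engine,
    base=Base,
    session=SessionLocal(),
)
create_migration_script(alembic_cfg, engine, message="init tables")
upgrade_database(alembic_cfg, engine)      # apply migrations up to "head"
# downgrade_database(alembic_cfg, engine)  # roll back one revision if needed
```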
def clean_alembic_migration(alembic_cfg: AlembicConfig, engine: Engine) -> None:
    """Clean Alembic migration scripts and history.

    Args:
        alembic_cfg: Alembic config object
        engine: SQLAlchemy engine instance
    """
    import shutil

    # Get the migration script location
    script_location = alembic_cfg.get_main_option("script_location")
    print(f"Delete migration script location: {script_location}")

    # Delete all migration script files
    for file in os.listdir(script_location):
        if file.startswith("versions"):
            filepath = os.path.join(script_location, file)
            print(f"Delete migration script file: {filepath}")
            if os.path.isfile(filepath):
                os.remove(filepath)
            else:
                shutil.rmtree(filepath, ignore_errors=True)

    # Delete the Alembic version table if it exists
    version_table = alembic_cfg.get_main_option("version_table") or "alembic_version"
    if version_table:
        with engine.connect() as connection:
            print(f"Delete Alembic version table: {version_table}")
            connection.execute(text(f"DROP TABLE IF EXISTS {version_table}"))

    print("Cleaned Alembic migration scripts and history")


_MIGRATION_SOLUTION = """
**Solution 1:**

Run the following command to upgrade the database.
```commandline
dbgpt db migration upgrade
```

**Solution 2:**

Run the following command to clean the migration scripts and migration history.
```commandline
dbgpt db migration clean -y
```

**Solution 3:**

If you have already run the above commands but the error still exists,
you can try the following command to clean the migration scripts, the migration history and your data.
Warning: this command will delete all your data!!! Please use it with caution.

```commandline
dbgpt db migration clean --drop_all_tables -y --confirm_drop_all_tables
```
or
```commandline
rm -rf pilot/meta_data/alembic/versions/*
rm -rf pilot/meta_data/alembic/dbgpt.db
```
"""


def _ddl_init_and_upgrade(
    default_meta_data_path: str,
    disable_alembic_upgrade: bool,
    alembic_ini_path: Optional[str] = None,
    script_location: Optional[str] = None,
):
    """Initialize and upgrade database metadata

    Args:
        default_meta_data_path (str): default meta data path
        disable_alembic_upgrade (bool): Whether to disable alembic initializing and upgrading database metadata
        alembic_ini_path (Optional[str]): alembic ini path
        script_location (Optional[str]): alembic script location
    """
    if disable_alembic_upgrade:
        logger.info(
            "disable_alembic_upgrade is true, not to initialize and upgrade database metadata with alembic"
        )
        return
    else:
        warn_msg = (
            "Initialize and upgrade database metadata with alembic, "
            "only run this in your development environment; if you deploy in a production environment, "
            "please run the webserver with --disable_alembic_upgrade (`python dbgpt/app/dbgpt_server.py "
            "--disable_alembic_upgrade`).\n"
            "We suggest you use `dbgpt db migration` to initialize and upgrade database metadata with alembic; "
            "you can run `dbgpt db migration --help` to get more information."
        )
        logger.warning(warn_msg)
    from dbgpt.storage.metadata.db_manager import db

    alembic_cfg = create_alembic_config(
        default_meta_data_path,
        db.engine,
        db.Model,
        db.session(),
        alembic_ini_path,
        script_location,
    )
    try:
        create_migration_script(alembic_cfg, db.engine)
        upgrade_database(alembic_cfg, db.engine)
    except CommandError as e:
        if "Target database is not up to date" in str(e):
            logger.error(
                f"Initialize and upgrade database metadata with alembic failed, error detail: {str(e)} "
                f"you can try the following solutions:\n{_MIGRATION_SOLUTION}\n"
            )
            raise Exception(
                "Initialize and upgrade database metadata with alembic failed, "
                "you can see the error and solutions above"
            ) from e
        else:
            raise e
@@ -39,6 +39,31 @@ def PublicAPI(*args, **kwargs):
     return decorator
 
 
+def DeveloperAPI(*args, **kwargs):
+    """Decorator to mark a function or class as a developer API.
+
+    Developer APIs are low-level APIs for advanced users and may change across major versions.
+
+    Examples:
+        >>> from dbgpt.util.annotations import DeveloperAPI
+        >>> @DeveloperAPI
+        ... def foo():
+        ...     pass
+
+    """
+    if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
+        return DeveloperAPI()(args[0])
+
+    def decorator(obj):
+        _modify_docstring(
+            obj,
+            "**DeveloperAPI:** This API is for advanced users and may change across major versions.",
+        )
+        return obj
+
+    return decorator
+
+
 def _modify_docstring(obj, message: str = None):
     if not message:
         return
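A quick illustration of the new decorator (a sketch: `tune_cache` is a hypothetical function, and since `_modify_docstring`'s body is not shown in this hunk, the exact placement of the notice in the docstring is an assumption):

```python
from dbgpt.util.annotations import DeveloperAPI


@DeveloperAPI  # usable with or without parentheses, per the callable() check above
def tune_cache(size: int) -> None:
    """Tune the internal cache size."""


# After decoration, tune_cache.__doc__ additionally carries the
# **DeveloperAPI:** notice injected by _modify_docstring.
print(tune_cache.__doc__)
```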
14 dbgpt/util/pagination_utils.py Normal file
@@ -0,0 +1,14 @@
from typing import TypeVar, Generic, List
|
||||||
|
from dbgpt._private.pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
class PaginationResult(BaseModel, Generic[T]):
|
||||||
|
"""Pagination result"""
|
||||||
|
|
||||||
|
items: List[T] = Field(..., description="The items in the current page")
|
||||||
|
total_count: int = Field(..., description="Total number of items")
|
||||||
|
total_pages: int = Field(..., description="total number of pages")
|
||||||
|
page: int = Field(..., description="Current page number")
|
||||||
|
page_size: int = Field(..., description="Number of items per page")
|
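A small sketch of building a `PaginationResult` by hand (the data is illustrative; in DB-GPT this object is normally produced by `paginate_query`):

```python
# Illustrative construction of a PaginationResult; values are assumptions.
import math

items = [f"row_{i}" for i in range(100)]
page, page_size = 2, 20
result = PaginationResult(
    items=items[(page - 1) * page_size : page * page_size],
    total_count=len(items),
    total_pages=math.ceil(len(items) / page_size),
    page=page,
    page_size=page_size,
)
assert result.total_pages == 5 and len(result.items) == 20
```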
@@ -41,4 +41,6 @@ class JsonSerializer(Serializer):
         # Convert bytes back to JSON and then to the specified class
         json_data = json.loads(data.decode(JSON_ENCODING))
         # Assume that the cls has an __init__ that accepts a dictionary
-        return cls(**json_data)
+        obj = cls(**json_data)
+        obj.set_serializer(self)
+        return obj
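The practical effect of this two-line change is that a deserialized object now carries a reference back to the serializer that produced it. A hedged sketch (`MyItem` is a hypothetical class whose `__init__` accepts the serialized dict's keys; only `deserialize` and `set_serializer` come from the hunk above):

```python
# Hedged sketch of the behavior after this change; MyItem is hypothetical.
serializer = JsonSerializer()
obj = serializer.deserialize(b'{"data": "hello"}', MyItem)
# Previously the caller had to attach a serializer manually; now
# obj.set_serializer(serializer) has already been called, so the object
# can be re-serialized without extra wiring.
```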
@@ -73,9 +73,11 @@ def extract_content_open_ending(long_string, s1, s2, is_include: bool = False):
     return match_map
 
 
-if __name__ == "__main__":
-    s = "abcd123efghijkjhhh456xxx123aa456yyy123bb456xx123"
-    s1 = "123"
-    s2 = "456"
-
-    print(extract_content_open_ending(s, s1, s2, True))
+def _to_str(x, charset="utf8", errors="strict"):
+    if x is None or isinstance(x, str):
+        return x
+
+    if isinstance(x, bytes):
+        return x.decode(charset, errors)
+
+    return str(x)
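The new `_to_str` helper normalizes heterogeneous values to `str`: `None` and `str` pass through unchanged, `bytes` are decoded, and everything else falls back to `str()`. A quick illustration of the behavior defined above:

```python
# Expected behavior of _to_str, per the diff above:
assert _to_str(None) is None
assert _to_str("abc") == "abc"
assert _to_str(b"\xe4\xb8\xad") == "中"  # bytes decoded with utf8 by default
assert _to_str(123) == "123"             # anything else falls back to str()
```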
@@ -71,3 +71,66 @@ Download and install `Microsoft C++ Build Tools` from [visual-cpp-build-tools](h
 1. update your mysql username and password in docker/examples/metadata/duckdb2mysql.py
 2. python docker/examples/metadata/duckdb2mysql.py
 ```
+
+##### Q8: `How to manage and migrate my database`
+
+You can use the `dbgpt db migration` command to manage and migrate your database.
+
+See the following command for details.
+```commandline
+dbgpt db migration --help
+```
+
+First, you need to create a migration script (just once, unless you clean it).
+This command will create an `alembic` directory in your `pilot/meta_data` directory and an initial migration script in it.
+```commandline
+dbgpt db migration init
+```
+
+Then you can upgrade your database with the following command.
+```commandline
+dbgpt db migration upgrade
+```
+
+Every time you change the model or pull the latest code from the DB-GPT repository, you need to create a new migration script.
+```commandline
+dbgpt db migration migrate -m "your message"
+```
+
+Then you can upgrade your database with the following command.
+```commandline
+dbgpt db migration upgrade
+```
+
+
+##### Q9: `alembic.util.exc.CommandError: Target database is not up to date.`
+
+**Solution 1:**
+
+Run the following command to upgrade the database.
+```commandline
+dbgpt db migration upgrade
+```
+
+**Solution 2:**
+
+Run the following command to clean the migration scripts and migration history.
+```commandline
+dbgpt db migration clean -y
+```
+
+**Solution 3:**
+
+If you have already run the above commands but the error still exists,
+you can try the following command to clean the migration scripts, the migration history and your data.
+Warning: this command will delete all your data!!! Please use it with caution.
+
+```commandline
+dbgpt db migration clean --drop_all_tables -y --confirm_drop_all_tables
+```
+or
+```commandline
+rm -rf pilot/meta_data/alembic/versions/*
+rm -rf pilot/meta_data/alembic/dbgpt.db
+```
@@ -97,6 +97,8 @@ Configure the proxy and modify LLM_MODEL, PROXY_API_URL and API_KEY in the `.env
 LLM_MODEL=chatgpt_proxyllm
 PROXY_API_KEY={your-openai-sk}
 PROXY_SERVER_URL=https://api.openai.com/v1/chat/completions
+# If you use gpt-4
+# PROXYLLM_BACKEND=gpt-4
 ```
 </TabItem>
 <TabItem value="qwen" label="通义千问">
@@ -3,7 +3,7 @@ from sqlalchemy import pool
 
 from alembic import context
 
-from dbgpt.storage.metadata.meta_data import Base, engine
+from dbgpt.storage.metadata.db_manager import db
 
 # this is the Alembic Config object, which provides
 # access to the values within the .ini file in use.
@@ -13,8 +13,7 @@ config = context.config
 # add your model's MetaData object here
 # for 'autogenerate' support
 # from myapp import mymodel
 # target_metadata = mymodel.Base.metadata
-target_metadata = Base.metadata
 
 # other values from the config, defined by the needs of env.py,
 # can be acquired:
@@ -34,6 +33,8 @@ def run_migrations_offline() -> None:
     script output.
 
     """
+    engine = db.engine
+    target_metadata = db.metadata
     url = config.get_main_option(engine.url)
     context.configure(
         url=url,
@@ -53,12 +54,8 @@ def run_migrations_online() -> None:
     and associate a connection with the context.
 
     """
-    connectable = engine_from_config(
-        config.get_section(config.config_ini_section, {}),
-        prefix="sqlalchemy.",
-        poolclass=pool.NullPool,
-    )
-
+    engine = db.engine
+    target_metadata = db.metadata
     with engine.connect() as connection:
         if engine.dialect.name == "sqlite":
             context.configure(