feat(rag): Support rag retriever evaluation (#1291)
This commit is contained in:
parent cd2dcc253c
commit adaa68eb00
@@ -278,6 +278,7 @@ CREATE TABLE `dbgpt_serve_flow` (
    `flow_category` varchar(64) DEFAULT NULL COMMENT 'Flow category',
    `description` varchar(512) DEFAULT NULL COMMENT 'Flow description',
    `state` varchar(32) DEFAULT NULL COMMENT 'Flow state',
    `error_message` varchar(512) NULL comment 'Error message',
    `source` varchar(64) DEFAULT NULL COMMENT 'Flow source',
    `source_url` varchar(512) DEFAULT NULL COMMENT 'Flow source url',
    `version` varchar(32) DEFAULT NULL COMMENT 'Flow version',
0 assets/schema/upgrade/v0_5_2/upgrade_to_v0.5.2.sql Normal file
395 assets/schema/upgrade/v0_5_2/v0.5.1.sql Normal file
@@ -0,0 +1,395 @@
-- Full SQL of v0.5.1, please do not modify this file (it must be the same as the file in the release package)

CREATE DATABASE IF NOT EXISTS dbgpt;
use dbgpt;

-- For alembic migration tool
CREATE TABLE IF NOT EXISTS `alembic_version`
(
    version_num VARCHAR(32) NOT NULL,
    CONSTRAINT alembic_version_pkc PRIMARY KEY (version_num)
) DEFAULT CHARSET=utf8mb4;

CREATE TABLE IF NOT EXISTS `knowledge_space`
(
    `id` int NOT NULL AUTO_INCREMENT COMMENT 'auto increment id',
    `name` varchar(100) NOT NULL COMMENT 'knowledge space name',
    `vector_type` varchar(50) NOT NULL COMMENT 'vector type',
    `desc` varchar(500) NOT NULL COMMENT 'description',
    `owner` varchar(100) DEFAULT NULL COMMENT 'owner',
    `context` TEXT DEFAULT NULL COMMENT 'context argument',
    `gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
    `gmt_modified` TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time',
    PRIMARY KEY (`id`),
    KEY `idx_name` (`name`) COMMENT 'index:idx_name'
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT='knowledge space table';

CREATE TABLE IF NOT EXISTS `knowledge_document`
(
    `id` int NOT NULL AUTO_INCREMENT COMMENT 'auto increment id',
    `doc_name` varchar(100) NOT NULL COMMENT 'document path name',
    `doc_type` varchar(50) NOT NULL COMMENT 'doc type',
    `space` varchar(50) NOT NULL COMMENT 'knowledge space',
    `chunk_size` int NOT NULL COMMENT 'chunk size',
    `last_sync` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'last sync time',
    `status` varchar(50) NOT NULL COMMENT 'status TODO,RUNNING,FAILED,FINISHED',
    `content` LONGTEXT NOT NULL COMMENT 'knowledge embedding sync result',
    `result` TEXT NULL COMMENT 'knowledge content',
    `vector_ids` LONGTEXT NULL COMMENT 'vector_ids',
    `summary` LONGTEXT NULL COMMENT 'knowledge summary',
    `gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
    `gmt_modified` TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time',
    PRIMARY KEY (`id`),
    KEY `idx_doc_name` (`doc_name`) COMMENT 'index:idx_doc_name'
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT='knowledge document table';

CREATE TABLE IF NOT EXISTS `document_chunk`
(
    `id` int NOT NULL AUTO_INCREMENT COMMENT 'auto increment id',
    `doc_name` varchar(100) NOT NULL COMMENT 'document path name',
    `doc_type` varchar(50) NOT NULL COMMENT 'doc type',
    `document_id` int NOT NULL COMMENT 'document parent id',
    `content` longtext NOT NULL COMMENT 'chunk content',
    `meta_info` varchar(200) NOT NULL COMMENT 'metadata info',
    `gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
    `gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time',
    PRIMARY KEY (`id`),
    KEY `idx_document_id` (`document_id`) COMMENT 'index:document_id'
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT='knowledge document chunk detail';


CREATE TABLE IF NOT EXISTS `connect_config`
(
    `id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id',
    `db_type` varchar(255) NOT NULL COMMENT 'db type',
    `db_name` varchar(255) NOT NULL COMMENT 'db name',
    `db_path` varchar(255) DEFAULT NULL COMMENT 'file db path',
    `db_host` varchar(255) DEFAULT NULL COMMENT 'db connect host(not file db)',
    `db_port` varchar(255) DEFAULT NULL COMMENT 'db connect port(not file db)',
    `db_user` varchar(255) DEFAULT NULL COMMENT 'db user',
    `db_pwd` varchar(255) DEFAULT NULL COMMENT 'db password',
    `comment` text COMMENT 'db comment',
    `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code',
    PRIMARY KEY (`id`),
    UNIQUE KEY `uk_db` (`db_name`),
    KEY `idx_q_db_type` (`db_type`)
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT 'Connection config';

CREATE TABLE IF NOT EXISTS `chat_history`
(
    `id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id',
    `conv_uid` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'Conversation record unique id',
    `chat_mode` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'Conversation scene mode',
    `summary` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'Conversation record summary',
    `user_name` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'interlocutor',
    `messages` text COLLATE utf8mb4_unicode_ci COMMENT 'Conversation details',
    `message_ids` text COLLATE utf8mb4_unicode_ci COMMENT 'Message id list, split by comma',
    `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code',
    `gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
    `gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time',
    UNIQUE KEY `conv_uid` (`conv_uid`),
    PRIMARY KEY (`id`)
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT 'Chat history';

CREATE TABLE IF NOT EXISTS `chat_history_message`
(
    `id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id',
    `conv_uid` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'Conversation record unique id',
    `index` int NOT NULL COMMENT 'Message index',
    `round_index` int NOT NULL COMMENT 'Round of conversation',
    `message_detail` text COLLATE utf8mb4_unicode_ci COMMENT 'Message details, json format',
    `gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
    `gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time',
    UNIQUE KEY `message_uid_index` (`conv_uid`, `index`),
    PRIMARY KEY (`id`)
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT 'Chat history message';

CREATE TABLE IF NOT EXISTS `chat_feed_back`
(
    `id` bigint(20) NOT NULL AUTO_INCREMENT,
    `conv_uid` varchar(128) DEFAULT NULL COMMENT 'Conversation ID',
    `conv_index` int(4) DEFAULT NULL COMMENT 'Round of conversation',
    `score` int(1) DEFAULT NULL COMMENT 'Score of user',
    `ques_type` varchar(32) DEFAULT NULL COMMENT 'User question category',
    `question` longtext DEFAULT NULL COMMENT 'User question',
    `knowledge_space` varchar(128) DEFAULT NULL COMMENT 'Knowledge space name',
    `messages` longtext DEFAULT NULL COMMENT 'The details of user feedback',
    `user_name` varchar(128) DEFAULT NULL COMMENT 'User name',
    `gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
    `gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time',
    PRIMARY KEY (`id`),
    UNIQUE KEY `uk_conv` (`conv_uid`,`conv_index`),
    KEY `idx_conv` (`conv_uid`,`conv_index`)
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT='User feedback table';


CREATE TABLE IF NOT EXISTS `my_plugin`
(
    `id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id',
    `tenant` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'user tenant',
    `user_code` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'user code',
    `user_name` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'user name',
    `name` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'plugin name',
    `file_name` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'plugin package file name',
    `type` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin type',
    `version` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin version',
    `use_count` int DEFAULT NULL COMMENT 'plugin total use count',
    `succ_count` int DEFAULT NULL COMMENT 'plugin total success count',
    `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code',
    `gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'plugin install time',
    PRIMARY KEY (`id`),
    UNIQUE KEY `name` (`name`)
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='User plugin table';

CREATE TABLE IF NOT EXISTS `plugin_hub`
(
    `id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id',
    `name` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'plugin name',
    `description` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'plugin description',
    `author` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin author',
    `email` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin author email',
    `type` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin type',
    `version` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin version',
    `storage_channel` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin storage channel',
    `storage_url` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin download url',
    `download_param` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin download param',
    `gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'plugin upload time',
    `installed` int DEFAULT NULL COMMENT 'plugin already installed count',
    PRIMARY KEY (`id`),
    UNIQUE KEY `name` (`name`)
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='Plugin Hub table';


CREATE TABLE IF NOT EXISTS `prompt_manage`
(
    `id` int(11) NOT NULL AUTO_INCREMENT,
    `chat_scene` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Chat scene',
    `sub_chat_scene` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Sub chat scene',
    `prompt_type` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt type: common or private',
    `prompt_name` varchar(256) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'prompt name',
    `content` longtext COLLATE utf8mb4_unicode_ci COMMENT 'Prompt content',
    `input_variables` varchar(1024) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt input variables(split by comma)',
    `model` varchar(128) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt model name(we can use different models for different prompt)',
    `prompt_language` varchar(32) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt language(eg:en, zh-cn)',
    `prompt_format` varchar(32) COLLATE utf8mb4_unicode_ci DEFAULT 'f-string' COMMENT 'Prompt format(eg: f-string, jinja2)',
    `prompt_desc` varchar(512) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt description',
    `user_name` varchar(128) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'User name',
    `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code',
    `gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
    `gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time',
    PRIMARY KEY (`id`),
    UNIQUE KEY `prompt_name_uiq` (`prompt_name`, `sys_code`, `prompt_language`, `model`),
    KEY `gmt_created_idx` (`gmt_created`)
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='Prompt management table';

CREATE TABLE IF NOT EXISTS `gpts_conversations` (
    `id` int(11) NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id',
    `conv_id` varchar(255) NOT NULL COMMENT 'The unique id of the conversation record',
    `user_goal` text NOT NULL COMMENT 'User''s goals content',
    `gpts_name` varchar(255) NOT NULL COMMENT 'The gpts name',
    `state` varchar(255) DEFAULT NULL COMMENT 'The gpts state',
    `max_auto_reply_round` int(11) NOT NULL COMMENT 'max auto reply round',
    `auto_reply_count` int(11) NOT NULL COMMENT 'auto reply count',
    `user_code` varchar(255) DEFAULT NULL COMMENT 'user code',
    `sys_code` varchar(255) DEFAULT NULL COMMENT 'system app',
    `created_at` datetime DEFAULT NULL COMMENT 'create time',
    `updated_at` datetime DEFAULT NULL COMMENT 'last update time',
    `team_mode` varchar(255) NULL COMMENT 'agent team work mode',

    PRIMARY KEY (`id`),
    UNIQUE KEY `uk_gpts_conversations` (`conv_id`),
    KEY `idx_gpts_name` (`gpts_name`)
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT="gpt conversations";

CREATE TABLE IF NOT EXISTS `gpts_instance` (
    `id` int(11) NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id',
    `gpts_name` varchar(255) NOT NULL COMMENT 'Current AI assistant name',
    `gpts_describe` varchar(2255) NOT NULL COMMENT 'Current AI assistant describe',
    `resource_db` text COMMENT 'List of structured database names contained in the current gpts',
    `resource_internet` text COMMENT 'Is it possible to retrieve information from the internet',
    `resource_knowledge` text COMMENT 'List of unstructured database names contained in the current gpts',
    `gpts_agents` varchar(1000) DEFAULT NULL COMMENT 'List of agents names contained in the current gpts',
    `gpts_models` varchar(1000) DEFAULT NULL COMMENT 'List of llm model names contained in the current gpts',
    `language` varchar(100) DEFAULT NULL COMMENT 'gpts language',
    `user_code` varchar(255) NOT NULL COMMENT 'user code',
    `sys_code` varchar(255) DEFAULT NULL COMMENT 'system app code',
    `created_at` datetime DEFAULT NULL COMMENT 'create time',
    `updated_at` datetime DEFAULT NULL COMMENT 'last update time',
    `team_mode` varchar(255) NOT NULL COMMENT 'Team work mode',
    `is_sustainable` tinyint(1) NOT NULL COMMENT 'Applications for sustainable dialogue',
    PRIMARY KEY (`id`),
    UNIQUE KEY `uk_gpts` (`gpts_name`)
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT="gpts instance";

CREATE TABLE `gpts_messages` (
    `id` int(11) NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id',
    `conv_id` varchar(255) NOT NULL COMMENT 'The unique id of the conversation record',
    `sender` varchar(255) NOT NULL COMMENT 'Who speaking in the current conversation turn',
    `receiver` varchar(255) NOT NULL COMMENT 'Who receive message in the current conversation turn',
    `model_name` varchar(255) DEFAULT NULL COMMENT 'message generate model',
    `rounds` int(11) NOT NULL COMMENT 'dialogue turns',
    `content` text COMMENT 'Content of the speech',
    `current_goal` text COMMENT 'The target corresponding to the current message',
    `context` text COMMENT 'Current conversation context',
    `review_info` text COMMENT 'Current conversation review info',
    `action_report` text COMMENT 'Current conversation action report',
    `role` varchar(255) DEFAULT NULL COMMENT 'The role of the current message content',
    `created_at` datetime DEFAULT NULL COMMENT 'create time',
    `updated_at` datetime DEFAULT NULL COMMENT 'last update time',
    PRIMARY KEY (`id`),
    KEY `idx_q_messages` (`conv_id`,`rounds`,`sender`)
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT="gpts message";


CREATE TABLE `gpts_plans` (
    `id` int(11) NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id',
    `conv_id` varchar(255) NOT NULL COMMENT 'The unique id of the conversation record',
    `sub_task_num` int(11) NOT NULL COMMENT 'Subtask number',
    `sub_task_title` varchar(255) NOT NULL COMMENT 'subtask title',
    `sub_task_content` text NOT NULL COMMENT 'subtask content',
    `sub_task_agent` varchar(255) DEFAULT NULL COMMENT 'Available agents corresponding to subtasks',
    `resource_name` varchar(255) DEFAULT NULL COMMENT 'resource name',
    `rely` varchar(255) DEFAULT NULL COMMENT 'Subtask dependencies, like: 1,2,3',
    `agent_model` varchar(255) DEFAULT NULL COMMENT 'LLM model used by subtask processing agents',
    `retry_times` int(11) DEFAULT NULL COMMENT 'number of retries',
    `max_retry_times` int(11) DEFAULT NULL COMMENT 'Maximum number of retries',
    `state` varchar(255) DEFAULT NULL COMMENT 'subtask status',
    `result` longtext COMMENT 'subtask result',
    `created_at` datetime DEFAULT NULL COMMENT 'create time',
    `updated_at` datetime DEFAULT NULL COMMENT 'last update time',
    PRIMARY KEY (`id`),
    UNIQUE KEY `uk_sub_task` (`conv_id`,`sub_task_num`)
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT="gpt plan";

-- dbgpt.dbgpt_serve_flow definition
CREATE TABLE `dbgpt_serve_flow` (
    `id` int NOT NULL AUTO_INCREMENT COMMENT 'Auto increment id',
    `uid` varchar(128) NOT NULL COMMENT 'Unique id',
    `dag_id` varchar(128) DEFAULT NULL COMMENT 'DAG id',
    `name` varchar(128) DEFAULT NULL COMMENT 'Flow name',
    `flow_data` text COMMENT 'Flow data, JSON format',
    `user_name` varchar(128) DEFAULT NULL COMMENT 'User name',
    `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code',
    `gmt_created` datetime DEFAULT NULL COMMENT 'Record creation time',
    `gmt_modified` datetime DEFAULT NULL COMMENT 'Record update time',
    `flow_category` varchar(64) DEFAULT NULL COMMENT 'Flow category',
    `description` varchar(512) DEFAULT NULL COMMENT 'Flow description',
    `state` varchar(32) DEFAULT NULL COMMENT 'Flow state',
    `error_message` varchar(512) NULL comment 'Error message',
    `source` varchar(64) DEFAULT NULL COMMENT 'Flow source',
    `source_url` varchar(512) DEFAULT NULL COMMENT 'Flow source url',
    `version` varchar(32) DEFAULT NULL COMMENT 'Flow version',
    `label` varchar(128) DEFAULT NULL COMMENT 'Flow label',
    `editable` int DEFAULT NULL COMMENT 'Editable, 0: editable, 1: not editable',
    PRIMARY KEY (`id`),
    UNIQUE KEY `uk_uid` (`uid`),
    KEY `ix_dbgpt_serve_flow_sys_code` (`sys_code`),
    KEY `ix_dbgpt_serve_flow_uid` (`uid`),
    KEY `ix_dbgpt_serve_flow_dag_id` (`dag_id`),
    KEY `ix_dbgpt_serve_flow_user_name` (`user_name`),
    KEY `ix_dbgpt_serve_flow_name` (`name`)
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;

-- dbgpt.gpts_app definition
CREATE TABLE `gpts_app` (
    `id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id',
    `app_code` varchar(255) NOT NULL COMMENT 'Current AI assistant code',
    `app_name` varchar(255) NOT NULL COMMENT 'Current AI assistant name',
    `app_describe` varchar(2255) NOT NULL COMMENT 'Current AI assistant describe',
    `language` varchar(100) NOT NULL COMMENT 'gpts language',
    `team_mode` varchar(255) NOT NULL COMMENT 'Team work mode',
    `team_context` text COMMENT 'The execution logic and team member content that teams with different working modes rely on',
    `user_code` varchar(255) DEFAULT NULL COMMENT 'user code',
    `sys_code` varchar(255) DEFAULT NULL COMMENT 'system app code',
    `created_at` datetime DEFAULT NULL COMMENT 'create time',
    `updated_at` datetime DEFAULT NULL COMMENT 'last update time',
    `icon` varchar(1024) DEFAULT NULL COMMENT 'app icon, url',
    PRIMARY KEY (`id`),
    UNIQUE KEY `uk_gpts_app` (`app_name`)
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;

CREATE TABLE `gpts_app_collection` (
    `id` int(11) NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id',
    `app_code` varchar(255) NOT NULL COMMENT 'Current AI assistant code',
    `user_code` int(11) NOT NULL COMMENT 'user code',
    `sys_code` varchar(255) NOT NULL COMMENT 'system app code',
    `created_at` datetime DEFAULT NULL COMMENT 'create time',
    `updated_at` datetime DEFAULT NULL COMMENT 'last update time',
    PRIMARY KEY (`id`),
    KEY `idx_app_code` (`app_code`),
    KEY `idx_user_code` (`user_code`)
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT="gpt collections";

-- dbgpt.gpts_app_detail definition
CREATE TABLE `gpts_app_detail` (
    `id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id',
    `app_code` varchar(255) NOT NULL COMMENT 'Current AI assistant code',
    `app_name` varchar(255) NOT NULL COMMENT 'Current AI assistant name',
    `agent_name` varchar(255) NOT NULL COMMENT 'Agent name',
    `node_id` varchar(255) NOT NULL COMMENT 'Current AI assistant Agent Node id',
    `resources` text COMMENT 'Agent bind resource',
    `prompt_template` text COMMENT 'Agent bind template',
    `llm_strategy` varchar(25) DEFAULT NULL COMMENT 'Agent use llm strategy',
    `llm_strategy_value` text COMMENT 'Agent use llm strategy value',
    `created_at` datetime DEFAULT NULL COMMENT 'create time',
    `updated_at` datetime DEFAULT NULL COMMENT 'last update time',
    PRIMARY KEY (`id`),
    UNIQUE KEY `uk_gpts_app_agent_node` (`app_name`,`agent_name`,`node_id`)
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;

CREATE DATABASE IF NOT EXISTS EXAMPLE_1;
use EXAMPLE_1;
CREATE TABLE IF NOT EXISTS `users`
(
    `id` int NOT NULL AUTO_INCREMENT,
    `username` varchar(50) NOT NULL COMMENT 'username',
    `password` varchar(50) NOT NULL COMMENT 'password',
    `email` varchar(50) NOT NULL COMMENT 'email',
    `phone` varchar(20) DEFAULT NULL COMMENT 'phone number',
    PRIMARY KEY (`id`),
    KEY `idx_username` (`username`) COMMENT 'index: query by username'
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT='chat users table';

INSERT INTO users (username, password, email, phone)
VALUES ('user_1', 'password_1', 'user_1@example.com', '12345678901');
INSERT INTO users (username, password, email, phone)
VALUES ('user_2', 'password_2', 'user_2@example.com', '12345678902');
INSERT INTO users (username, password, email, phone)
VALUES ('user_3', 'password_3', 'user_3@example.com', '12345678903');
INSERT INTO users (username, password, email, phone)
VALUES ('user_4', 'password_4', 'user_4@example.com', '12345678904');
INSERT INTO users (username, password, email, phone)
VALUES ('user_5', 'password_5', 'user_5@example.com', '12345678905');
INSERT INTO users (username, password, email, phone)
VALUES ('user_6', 'password_6', 'user_6@example.com', '12345678906');
INSERT INTO users (username, password, email, phone)
VALUES ('user_7', 'password_7', 'user_7@example.com', '12345678907');
INSERT INTO users (username, password, email, phone)
VALUES ('user_8', 'password_8', 'user_8@example.com', '12345678908');
INSERT INTO users (username, password, email, phone)
VALUES ('user_9', 'password_9', 'user_9@example.com', '12345678909');
INSERT INTO users (username, password, email, phone)
VALUES ('user_10', 'password_10', 'user_10@example.com', '12345678900');
INSERT INTO users (username, password, email, phone)
VALUES ('user_11', 'password_11', 'user_11@example.com', '12345678901');
INSERT INTO users (username, password, email, phone)
VALUES ('user_12', 'password_12', 'user_12@example.com', '12345678902');
INSERT INTO users (username, password, email, phone)
VALUES ('user_13', 'password_13', 'user_13@example.com', '12345678903');
INSERT INTO users (username, password, email, phone)
VALUES ('user_14', 'password_14', 'user_14@example.com', '12345678904');
INSERT INTO users (username, password, email, phone)
VALUES ('user_15', 'password_15', 'user_15@example.com', '12345678905');
INSERT INTO users (username, password, email, phone)
VALUES ('user_16', 'password_16', 'user_16@example.com', '12345678906');
INSERT INTO users (username, password, email, phone)
VALUES ('user_17', 'password_17', 'user_17@example.com', '12345678907');
INSERT INTO users (username, password, email, phone)
VALUES ('user_18', 'password_18', 'user_18@example.com', '12345678908');
INSERT INTO users (username, password, email, phone)
VALUES ('user_19', 'password_19', 'user_19@example.com', '12345678909');
INSERT INTO users (username, password, email, phone)
VALUES ('user_20', 'password_20', 'user_20@example.com', '12345678900');
@@ -167,6 +167,8 @@ EMBEDDING_MODEL_CONFIG = {
    # https://huggingface.co/BAAI/bge-large-zh
    "bge-large-zh": os.path.join(MODEL_PATH, "bge-large-zh"),
    "bge-base-zh": os.path.join(MODEL_PATH, "bge-base-zh"),
+    # https://huggingface.co/BAAI/bge-m3, bge-m3 needs normalize_embeddings=True
+    "bge-m3": os.path.join(MODEL_PATH, "bge-m3"),
    "gte-large-zh": os.path.join(MODEL_PATH, "gte-large-zh"),
    "gte-base-zh": os.path.join(MODEL_PATH, "gte-base-zh"),
    "sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2"),
@@ -7,6 +7,7 @@ from dbgpt.core.interface.cache import (  # noqa: F401
    CachePolicy,
    CacheValue,
)
+from dbgpt.core.interface.embeddings import Embeddings  # noqa: F401
from dbgpt.core.interface.llm import (  # noqa: F401
    DefaultMessageConverter,
    LLMClient,
@@ -103,4 +104,5 @@ __ALL__ = [
    "DefaultStorageItemAdapter",
    "QuerySpec",
    "StorageError",
+    "Embeddings",
]
@@ -55,6 +55,7 @@ from .trigger.http_trigger import (
    CommonLLMHttpResponseBody,
    HttpTrigger,
)
+from .trigger.iterator_trigger import IteratorTrigger

_request_http_trigger_available = False
try:
@@ -100,6 +101,7 @@ __all__ = [
    "TransformStreamAbsOperator",
    "Trigger",
    "HttpTrigger",
+    "IteratorTrigger",
    "CommonLLMHTTPRequestContext",
    "CommonLLMHttpResponseBody",
    "CommonLLMHttpRequestBody",
@@ -277,7 +277,7 @@ class InputOperator(BaseOperator, Generic[OUT]):
        return task_output


-class TriggerOperator(InputOperator, Generic[OUT]):
+class TriggerOperator(InputOperator[OUT], Generic[OUT]):
    """Operator node that triggers the DAG to run."""

    def __init__(self, **kwargs) -> None:
@@ -60,8 +60,8 @@ class DefaultWorkflowRunner(WorkflowRunner):
            streaming_call=streaming_call,
            node_name_to_ids=job_manager._node_name_to_ids,
        )
-        if node.dag:
-            self._running_dag_ctx[node.dag.dag_id] = dag_ctx
+        # if node.dag:
+        #     self._running_dag_ctx[node.dag.dag_id] = dag_ctx
        logger.info(
            f"Begin run workflow from end operator, id: {node.node_id}, runner: {self}"
        )
@@ -76,8 +76,8 @@ class DefaultWorkflowRunner(WorkflowRunner):
        if not streaming_call and node.dag:
            # streaming call not work for dag end
            await node.dag._after_dag_end()
-        if node.dag:
-            del self._running_dag_ctx[node.dag.dag_id]
+        # if node.dag:
+        #     del self._running_dag_ctx[node.dag.dag_id]
        return dag_ctx

    async def _execute_node(
@@ -3,11 +3,13 @@ from abc import ABC, abstractmethod
from enum import Enum
from typing import (
    Any,
+    AsyncIterable,
    AsyncIterator,
    Awaitable,
    Callable,
    Dict,
    Generic,
+    Iterable,
    List,
    Optional,
    TypeVar,
@@ -421,3 +423,40 @@ class InputSource(ABC, Generic[T]):
        Returns:
            TaskOutput[T]: The output object read from current source
        """

    @classmethod
    def from_data(cls, data: T) -> "InputSource[T]":
        """Create an InputSource from data.

        Args:
            data (T): The data to create the InputSource from.

        Returns:
            InputSource[T]: The InputSource created from the data.
        """
        from .task_impl import SimpleInputSource

        return SimpleInputSource(data, streaming=False)

    @classmethod
    def from_iterable(
        cls, iterable: Union[AsyncIterable[T], Iterable[T]]
    ) -> "InputSource[T]":
        """Create an InputSource from an iterable.

        Args:
            iterable (Union[AsyncIterable[T], Iterable[T]]): The iterable to create
                the InputSource from.

        Returns:
            InputSource[T]: The InputSource created from the iterable.
        """
        from .task_impl import SimpleInputSource

        return SimpleInputSource(iterable, streaming=True)

    @classmethod
    def from_callable(cls) -> "InputSource[T]":
        """Create an InputSource from a callable."""
        from .task_impl import SimpleCallDataInputSource

        return SimpleCallDataInputSource()
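The new factory methods can also be exercised outside a DAG. A minimal sketch, assuming only the `read` API together with the `DefaultTaskContext` and `TaskState` types that appear elsewhere in this diff (the task ids are illustrative):

import asyncio

from dbgpt.core.awel.task.base import InputSource, TaskState
from dbgpt.core.awel.task.task_impl import DefaultTaskContext


async def main():
    # A non-streaming source wraps a single value.
    source = InputSource.from_data(42)
    out = await source.read(DefaultTaskContext("task-1", TaskState.RUNNING, None))
    print(out.output)  # 42

    # A streaming source wraps an iterable; reading yields a stream output.
    stream_source = InputSource.from_iterable([1, 2, 3])
    stream_out = await stream_source.read(
        DefaultTaskContext("task-2", TaskState.RUNNING, None)
    )
    async for item in stream_out.output_stream:
        print(item)


asyncio.run(main())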
@@ -261,13 +261,42 @@ def _is_async_iterator(obj):
    )


def _is_async_iterable(obj):
    return hasattr(obj, "__aiter__") and callable(getattr(obj, "__aiter__", None))


def _is_iterator(obj):
    return (
        hasattr(obj, "__iter__")
        and callable(getattr(obj, "__iter__", None))
        and hasattr(obj, "__next__")
        and callable(getattr(obj, "__next__", None))
    )


def _is_iterable(obj):
    return hasattr(obj, "__iter__") and callable(getattr(obj, "__iter__", None))


async def _to_async_iterator(obj) -> AsyncIterator:
    if _is_async_iterable(obj):
        async for item in obj:
            yield item
    elif _is_iterable(obj):
        for item in obj:
            yield item
    else:
        raise ValueError(f"Can not convert {obj} to AsyncIterator")


class BaseInputSource(InputSource, ABC):
    """The base class of InputSource."""

-    def __init__(self) -> None:
+    def __init__(self, streaming: Optional[bool] = None) -> None:
        """Create a BaseInputSource."""
        super().__init__()
        self._is_read = False
+        self._streaming_data = streaming

    @abstractmethod
    def _read_data(self, task_ctx: TaskContext) -> Any:
@@ -286,10 +315,15 @@ class BaseInputSource(InputSource, ABC):
            ValueError: If the input source is a stream and has been read.
        """
        data = self._read_data(task_ctx)
-        if _is_async_iterator(data):
+        if self._streaming_data is None:
+            streaming_data = _is_async_iterator(data) or _is_iterator(data)
+        else:
+            streaming_data = self._streaming_data
+        if streaming_data:
            if self._is_read:
                raise ValueError(f"Input iterator {data} has been read!")
-            output: TaskOutput = SimpleStreamTaskOutput(data)
+            it_data = _to_async_iterator(data)
+            output: TaskOutput = SimpleStreamTaskOutput(it_data)
        else:
            output = SimpleTaskOutput(data)
        self._is_read = True
@@ -299,13 +333,13 @@ class BaseInputSource(InputSource, ABC):
class SimpleInputSource(BaseInputSource):
    """The default implementation of InputSource."""

-    def __init__(self, data: Any) -> None:
+    def __init__(self, data: Any, streaming: Optional[bool] = None) -> None:
        """Create a SimpleInputSource.

        Args:
            data (Any): The input data.
        """
-        super().__init__()
+        super().__init__(streaming=streaming)
        self._data = data

    def _read_data(self, task_ctx: TaskContext) -> Any:
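The `streaming` flag added above lets callers pin the behavior instead of relying on auto-detection. A minimal sketch of the difference, assuming a plain list as input (auto-detection treats a list as single data, because it is iterable but not an iterator):

from dbgpt.core.awel.task.task_impl import SimpleInputSource

# Auto-detection: a list is not an iterator, so it is read as one value.
auto_source = SimpleInputSource([1, 2, 3])

# Explicit streaming: the same list is converted to an async iterator and
# exposed as a stream of three items.
stream_source = SimpleInputSource([1, 2, 3], streaming=True)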
0 dbgpt/core/awel/tests/trigger/__init__.py Normal file
118 dbgpt/core/awel/tests/trigger/test_iterator_trigger.py Normal file
@@ -0,0 +1,118 @@
from typing import AsyncIterator

import pytest

from dbgpt.core.awel import (
    DAG,
    InputSource,
    MapOperator,
    StreamifyAbsOperator,
    TransformStreamAbsOperator,
)
from dbgpt.core.awel.trigger.iterator_trigger import IteratorTrigger


class NumberProducerOperator(StreamifyAbsOperator[int, int]):
    """Create a stream of numbers from 0 to `n-1`"""

    async def streamify(self, n: int) -> AsyncIterator[int]:
        for i in range(n):
            yield i


class MyStreamingOperator(TransformStreamAbsOperator[int, int]):
    async def transform_stream(self, data: AsyncIterator[int]) -> AsyncIterator[int]:
        async for i in data:
            yield i * i


async def _check_stream_results(stream_results, expected_len):
    assert len(stream_results) == expected_len
    for _, result in stream_results:
        i = 0
        async for num in result:
            assert num == i * i
            i += 1


@pytest.mark.asyncio
async def test_single_data():
    with DAG("test_single_data"):
        trigger_task = IteratorTrigger(data=2)
        task = MapOperator(lambda x: x * x)
        trigger_task >> task
        results = await trigger_task.trigger()
        assert len(results) == 1
        assert results[0][1] == 4

    with DAG("test_single_data_stream"):
        trigger_task = IteratorTrigger(data=2, streaming_call=True)
        number_task = NumberProducerOperator()
        task = MyStreamingOperator()
        trigger_task >> number_task >> task
        stream_results = await trigger_task.trigger()
        await _check_stream_results(stream_results, 1)


@pytest.mark.asyncio
async def test_list_data():
    with DAG("test_list_data"):
        trigger_task = IteratorTrigger(data=[0, 1, 2, 3])
        task = MapOperator(lambda x: x * x)
        trigger_task >> task
        results = await trigger_task.trigger()
        assert len(results) == 4
        assert results == [(0, 0), (1, 1), (2, 4), (3, 9)]

    with DAG("test_list_data_stream"):
        trigger_task = IteratorTrigger(data=[0, 1, 2, 3], streaming_call=True)
        number_task = NumberProducerOperator()
        task = MyStreamingOperator()
        trigger_task >> number_task >> task
        stream_results = await trigger_task.trigger()
        await _check_stream_results(stream_results, 4)


@pytest.mark.asyncio
async def test_async_iterator_data():
    async def async_iter():
        for i in range(4):
            yield i

    with DAG("test_async_iterator_data"):
        trigger_task = IteratorTrigger(data=async_iter())
        task = MapOperator(lambda x: x * x)
        trigger_task >> task
        results = await trigger_task.trigger()
        assert len(results) == 4
        assert results == [(0, 0), (1, 1), (2, 4), (3, 9)]

    with DAG("test_async_iterator_data_stream"):
        trigger_task = IteratorTrigger(data=async_iter(), streaming_call=True)
        number_task = NumberProducerOperator()
        task = MyStreamingOperator()
        trigger_task >> number_task >> task
        stream_results = await trigger_task.trigger()
        await _check_stream_results(stream_results, 4)


@pytest.mark.asyncio
async def test_input_source_data():
    with DAG("test_input_source_data"):
        trigger_task = IteratorTrigger(data=InputSource.from_iterable([0, 1, 2, 3]))
        task = MapOperator(lambda x: x * x)
        trigger_task >> task
        results = await trigger_task.trigger()
        assert len(results) == 4
        assert results == [(0, 0), (1, 1), (2, 4), (3, 9)]

    with DAG("test_input_source_data_stream"):
        trigger_task = IteratorTrigger(
            data=InputSource.from_iterable([0, 1, 2, 3]),
            streaming_call=True,
        )
        number_task = NumberProducerOperator()
        task = MyStreamingOperator()
        trigger_task >> number_task >> task
        stream_results = await trigger_task.trigger()
        await _check_stream_results(stream_results, 4)
@@ -2,16 +2,18 @@
from __future__ import annotations

from abc import ABC, abstractmethod
+from typing import Any, Generic

from ..operators.common_operator import TriggerOperator
+from ..task.base import OUT


-class Trigger(TriggerOperator, ABC):
+class Trigger(TriggerOperator[OUT], ABC, Generic[OUT]):
    """Base class for all trigger classes.

    Now only support http trigger.
    """

    @abstractmethod
-    async def trigger(self) -> None:
+    async def trigger(self, **kwargs) -> Any:
        """Trigger the workflow or a specific operation in the workflow."""
@@ -397,9 +397,9 @@ class HttpTrigger(Trigger):
        self._end_node: Optional[BaseOperator] = None
        self._register_to_app = register_to_app

-    async def trigger(self) -> None:
+    async def trigger(self, **kwargs) -> Any:
        """Trigger the DAG. Not used in HttpTrigger."""
-        pass
+        raise NotImplementedError("HttpTrigger does not support trigger directly")

    def register_to_app(self) -> bool:
        """Register the trigger to a FastAPI app.
148 dbgpt/core/awel/trigger/iterator_trigger.py Normal file
@@ -0,0 +1,148 @@
"""Trigger for iterator data."""

import asyncio
from typing import Any, AsyncIterator, Iterator, List, Optional, Tuple, Union, cast

from ..operators.base import BaseOperator
from ..task.base import InputSource, TaskState
from ..task.task_impl import DefaultTaskContext, _is_async_iterator, _is_iterable
from .base import Trigger

IterDataType = Union[InputSource, Iterator, AsyncIterator, Any]


async def _to_async_iterator(iter_data: IterDataType, task_id: str) -> AsyncIterator:
    """Convert iter_data to an async iterator."""
    if _is_async_iterator(iter_data):
        async for item in iter_data:  # type: ignore
            yield item
    elif _is_iterable(iter_data):
        for item in iter_data:  # type: ignore
            yield item
    elif isinstance(iter_data, InputSource):
        task_ctx: DefaultTaskContext[Any] = DefaultTaskContext(
            task_id, TaskState.RUNNING, None
        )
        data = await iter_data.read(task_ctx)
        if data.is_stream:
            async for item in data.output_stream:
                yield item
        else:
            yield data.output
    else:
        yield iter_data


class IteratorTrigger(Trigger):
    """Trigger for iterator data.

    Trigger the dag with iterator data.
    Return the list of results of the leaf nodes in the dag.
    The dag runs once for each item of the iterator data.
    """

    def __init__(
        self,
        data: IterDataType,
        parallel_num: int = 1,
        streaming_call: bool = False,
        **kwargs
    ):
        """Create an IteratorTrigger.

        Args:
            data (IterDataType): The iterator data.
            parallel_num (int, optional): The parallel number of the dag running.
                Defaults to 1.
            streaming_call (bool, optional): Whether the dag is a streaming call.
                Defaults to False.
        """
        self._iter_data = data
        self._parallel_num = parallel_num
        self._streaming_call = streaming_call
        super().__init__(**kwargs)

    async def trigger(
        self, parallel_num: Optional[int] = None, **kwargs
    ) -> List[Tuple[Any, Any]]:
        """Trigger the dag with iterator data.

        If the dag is a streaming call, return the list of async iterators.

        Examples:
            .. code-block:: python

                import asyncio
                from dbgpt.core.awel import DAG, IteratorTrigger, MapOperator

                with DAG("test_dag") as dag:
                    trigger_task = IteratorTrigger([0, 1, 2, 3])
                    task = MapOperator(lambda x: x * x)
                    trigger_task >> task
                    results = asyncio.run(trigger_task.trigger())
                    # First element of the tuple is the input data, the second
                    # element is the output data of the leaf node.
                    assert results == [(0, 0), (1, 1), (2, 4), (3, 9)]

            .. code-block:: python

                import asyncio
                from datasets import Dataset
                from dbgpt.core.awel import (
                    DAG,
                    IteratorTrigger,
                    MapOperator,
                    InputSource,
                )

                data_samples = {
                    "question": ["What is 1+1?", "What is 7*7?"],
                    "answer": [2, 49],
                }
                dataset = Dataset.from_dict(data_samples)
                with DAG("test_dag_stream") as dag:
                    trigger_task = IteratorTrigger(InputSource.from_iterable(dataset))
                    task = MapOperator(lambda x: x["answer"])
                    trigger_task >> task
                    results = asyncio.run(trigger_task.trigger())
                    assert results == [
                        ({"question": "What is 1+1?", "answer": 2}, 2),
                        ({"question": "What is 7*7?", "answer": 49}, 49),
                    ]

        Args:
            parallel_num (Optional[int], optional): The parallel number of the dag
                running. Defaults to None.

        Returns:
            List[Tuple[Any, Any]]: The list of results of the leaf nodes in the dag.
                The first element of the tuple is the input data, the second element
                is the output data of the leaf node.
        """
        dag = self.dag
        if not dag:
            raise ValueError("DAG is not set for IteratorTrigger")
        leaf_nodes = dag.leaf_nodes
        if len(leaf_nodes) != 1:
            raise ValueError("IteratorTrigger just support one leaf node in dag")
        end_node = cast(BaseOperator, leaf_nodes[0])
        streaming_call = self._streaming_call
        semaphore = asyncio.Semaphore(parallel_num or self._parallel_num)
        task_id = self.node_id

        async def call_stream(call_data: Any):
            async for out in await end_node.call_stream(call_data):
                yield out

        async def run_node(call_data: Any):
            async with semaphore:
                if streaming_call:
                    task_output = call_stream(call_data)
                else:
                    task_output = await end_node.call(call_data)
            return call_data, task_output

        tasks = []
        async for data in _to_async_iterator(self._iter_data, task_id):
            tasks.append(run_node(data))
        results = await asyncio.gather(*tasks)
        return results
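For the streaming path, each result tuple carries an async iterator that still has to be drained by the caller. A minimal sketch, assuming the same kind of StreamifyAbsOperator used in the tests above:

import asyncio
from typing import AsyncIterator

from dbgpt.core.awel import DAG, StreamifyAbsOperator
from dbgpt.core.awel.trigger.iterator_trigger import IteratorTrigger


class NumberProducerOperator(StreamifyAbsOperator[int, int]):
    """Stream the numbers 0..n-1 for each input n."""

    async def streamify(self, n: int) -> AsyncIterator[int]:
        for i in range(n):
            yield i


async def main():
    with DAG("streaming_example"):
        trigger = IteratorTrigger(data=[2, 3], streaming_call=True)
        trigger >> NumberProducerOperator()
        # Each tuple is (input value, async iterator over the stream).
        for call_data, stream in await trigger.trigger():
            print(f"input={call_data}:", [i async for i in stream])


asyncio.run(main())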
32 dbgpt/core/interface/embeddings.py Normal file
@@ -0,0 +1,32 @@
"""Interface for embedding models."""
import asyncio
from abc import ABC, abstractmethod
from typing import List


class Embeddings(ABC):
    """Interface for embedding models.

    Refer to `Langchain Embeddings <https://github.com/langchain-ai/langchain/tree/
    master/libs/langchain/langchain/embeddings>`_.
    """

    @abstractmethod
    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed search docs."""

    @abstractmethod
    def embed_query(self, text: str) -> List[float]:
        """Embed query text."""

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Asynchronous Embed search docs."""
        return await asyncio.get_running_loop().run_in_executor(
            None, self.embed_documents, texts
        )

    async def aembed_query(self, text: str) -> List[float]:
        """Asynchronous Embed query text."""
        return await asyncio.get_running_loop().run_in_executor(
            None, self.embed_query, text
        )
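A concrete embedder only has to supply the two synchronous methods; the async variants come for free through the executor wrappers. A toy sketch of a subclass (the hash-based vectors are purely illustrative, not a real embedding model):

import hashlib
from typing import List

from dbgpt.core.interface.embeddings import Embeddings


class ToyHashEmbeddings(Embeddings):
    """Illustrative embedder mapping text to a fixed-size pseudo-vector."""

    def __init__(self, dim: int = 8):
        self._dim = dim

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [self.embed_query(t) for t in texts]

    def embed_query(self, text: str) -> List[float]:
        digest = hashlib.sha256(text.encode("utf-8")).digest()
        # Scale the first `dim` bytes into [0, 1); real models return learned vectors.
        return [b / 255.0 for b in digest[: self._dim]]

Calling `await ToyHashEmbeddings().aembed_query("hello")` then runs `embed_query` on the default executor without any extra code in the subclass.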
253 dbgpt/core/interface/evaluation.py Normal file
@@ -0,0 +1,253 @@
"""Evaluation module."""
import asyncio
import string
from abc import ABC, abstractmethod
from typing import (
    TYPE_CHECKING,
    Any,
    AsyncIterator,
    Callable,
    Generic,
    Iterator,
    List,
    Optional,
    Sequence,
    TypeVar,
    Union,
)

from dbgpt._private.pydantic import BaseModel, Field
from dbgpt.util.similarity_util import calculate_cosine_similarity

from .embeddings import Embeddings
from .llm import LLMClient

if TYPE_CHECKING:
    from dbgpt.core.awel.task.base import InputSource

QueryType = Union[str, Any]
PredictionType = Union[str, Any]
ContextType = Union[str, Sequence[str], Any]
DatasetType = Union["InputSource", Iterator, AsyncIterator]


class BaseEvaluationResult(BaseModel):
    """Base evaluation result."""

    prediction: Optional[PredictionType] = Field(
        None,
        description="Prediction data(including the output of LLM, the data from "
        "retrieval, etc.)",
    )
    contexts: Optional[ContextType] = Field(None, description="Context data")
    score: Optional[float] = Field(None, description="Score for the prediction")
    passing: Optional[bool] = Field(
        None, description="Binary evaluation result (passing or not)"
    )
    metric_name: Optional[str] = Field(None, description="Name of the metric")


class EvaluationResult(BaseEvaluationResult):
    """Evaluation result.

    Output of a BaseEvaluator.
    """

    query: Optional[QueryType] = Field(None, description="Query data")
    raw_dataset: Optional[Any] = Field(None, description="Raw dataset")


Q = TypeVar("Q")
P = TypeVar("P")
C = TypeVar("C")


class EvaluationMetric(ABC, Generic[P, C]):
    """Base class for evaluation metric."""

    @property
    def name(self) -> str:
        """Name of the metric."""
        return self.__class__.__name__

    async def compute(
        self,
        prediction: P,
        contexts: Optional[Sequence[C]] = None,
    ) -> BaseEvaluationResult:
        """Compute the evaluation metric.

        Args:
            prediction(P): The prediction data.
            contexts(Optional[Sequence[C]]): The context data.

        Returns:
            BaseEvaluationResult: The evaluation result.
        """
        return await asyncio.get_running_loop().run_in_executor(
            None, self.sync_compute, prediction, contexts
        )

    def sync_compute(
        self,
        prediction: P,
        contexts: Optional[Sequence[C]] = None,
    ) -> BaseEvaluationResult:
        """Compute the evaluation metric.

        Args:
            prediction(P): The prediction data.
            contexts(Optional[Sequence[C]]): The context data.

        Returns:
            BaseEvaluationResult: The evaluation result.
        """
        raise NotImplementedError("sync_compute is not implemented")


class FunctionMetric(EvaluationMetric[P, C], Generic[P, C]):
    """Evaluation metric based on a function."""

    def __init__(
        self,
        name: str,
        func: Callable[
            [P, Optional[Sequence[C]]],
            BaseEvaluationResult,
        ],
    ):
        """Create a FunctionMetric.

        Args:
            name(str): The name of the metric.
            func(Callable[[P, Optional[Sequence[C]]], BaseEvaluationResult]):
                The function to use for evaluation.
        """
        self._name = name
        self.func = func

    @property
    def name(self) -> str:
        """Name of the metric."""
        return self._name

    async def compute(
        self,
        prediction: P,
        context: Optional[Sequence[C]] = None,
    ) -> BaseEvaluationResult:
        """Compute the evaluation metric."""
        return self.func(prediction, context)


class ExactMatchMetric(EvaluationMetric[str, str]):
    """Exact match metric.

    Just support string prediction and context.
    """

    def __init__(self, ignore_case: bool = False, ignore_punctuation: bool = False):
        """Create an ExactMatchMetric."""
        self._ignore_case = ignore_case
        self._ignore_punctuation = ignore_punctuation

    async def compute(
        self,
        prediction: str,
        contexts: Optional[Sequence[str]] = None,
    ) -> BaseEvaluationResult:
        """Compute the evaluation metric."""
        if self._ignore_case:
            prediction = prediction.lower()
            if contexts:
                contexts = [c.lower() for c in contexts]
        if self._ignore_punctuation:
            prediction = prediction.translate(str.maketrans("", "", string.punctuation))
            if contexts:
                contexts = [
                    c.translate(str.maketrans("", "", string.punctuation))
                    for c in contexts
                ]
        score = 0 if not contexts else float(prediction in contexts)
        return BaseEvaluationResult(
            prediction=prediction,
            contexts=contexts,
            score=score,
        )


class SimilarityMetric(EvaluationMetric[str, str]):
    """Similarity metric.

    Calculate the cosine similarity between a prediction and a list of contexts.
    """

    def __init__(self, embeddings: Embeddings):
        """Create a SimilarityMetric with embeddings."""
        self._embeddings = embeddings

    def sync_compute(
        self,
        prediction: str,
        contexts: Optional[Sequence[str]] = None,
    ) -> BaseEvaluationResult:
        """Compute the evaluation metric."""
        if not contexts:
            return BaseEvaluationResult(
                prediction=prediction,
                contexts=contexts,
                score=0.0,
            )
        try:
            import numpy as np
        except ImportError:
            raise ImportError("numpy is required for SimilarityMetric")

        similarity: np.ndarray = calculate_cosine_similarity(
            self._embeddings, prediction, contexts
        )
        return BaseEvaluationResult(
            prediction=prediction,
            contexts=contexts,
            score=float(similarity.mean()),
        )


class Evaluator(ABC):
    """Base Evaluator class."""

    def __init__(
        self,
        llm_client: Optional[LLMClient] = None,
    ):
        """Create an Evaluator."""
        self.llm_client = llm_client

    @abstractmethod
    async def evaluate(
        self,
        dataset: DatasetType,
        metrics: Optional[List[EvaluationMetric]] = None,
        query_key: str = "query",
        contexts_key: str = "contexts",
        prediction_key: str = "prediction",
        parallel_num: int = 1,
        **kwargs
    ) -> List[List[EvaluationResult]]:
        """Run evaluation with a dataset and metrics.

        Args:
            dataset(DatasetType): The dataset to evaluate.
            metrics(Optional[List[EvaluationMetric]]): The metrics to use for
                evaluation.
            query_key(str): The key for query in the dataset.
            contexts_key(str): The key for contexts in the dataset.
            prediction_key(str): The key for prediction in the dataset.
            parallel_num(int): The number of parallel tasks.
            kwargs: Additional arguments.

        Returns:
            List[List[EvaluationResult]]: The evaluation results, the length of the
                result equals to the length of the dataset. The first element in the
                list is the list of evaluation results for metrics.
        """
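A quick sketch of driving these metrics directly, outside of an Evaluator (the sample strings and the length-ratio scoring function are made up for illustration):

import asyncio

from dbgpt.core.interface.evaluation import (
    BaseEvaluationResult,
    ExactMatchMetric,
    FunctionMetric,
)


async def main():
    exact = ExactMatchMetric(ignore_case=True, ignore_punctuation=True)
    result = await exact.compute("Paris!", contexts=["paris", "london"])
    print(result.score)  # 1.0: the normalized prediction appears in the contexts

    # FunctionMetric wraps any plain scoring function.
    length_ratio = FunctionMetric(
        "length_ratio",
        lambda pred, ctxs: BaseEvaluationResult(
            prediction=pred,
            contexts=ctxs,
            score=len(pred) / max(len(ctxs[0]), 1) if ctxs else 0.0,
        ),
    )
    result = await length_ratio.compute("abcd", ["abcdefgh"])
    print(result.score)  # 0.5


asyncio.run(main())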
@@ -5,9 +5,10 @@ from typing import Any, List, Optional

from pydantic import BaseModel, Field

-from dbgpt.rag.chunk import Chunk
+from dbgpt.rag.chunk import Chunk, Document
from dbgpt.rag.extractor.base import Extractor
from dbgpt.rag.knowledge.base import ChunkStrategy, Knowledge
+from dbgpt.rag.text_splitter import TextSplitter


class SplitterType(Enum):
@@ -81,14 +82,14 @@ class ChunkManager:
        self._text_splitter = self._chunk_parameters.text_splitter
        self._splitter_type = self._chunk_parameters.splitter_type

-    def split(self, documents) -> List[Chunk]:
+    def split(self, documents: List[Document]) -> List[Chunk]:
        """Split a document into chunks."""
        text_splitter = self._select_text_splitter()
        if SplitterType.LANGCHAIN == self._splitter_type:
            documents = text_splitter.split_documents(documents)
            return [Chunk.langchain2chunk(document) for document in documents]
        elif SplitterType.LLAMA_INDEX == self._splitter_type:
-            nodes = text_splitter.split_text(documents)
+            nodes = text_splitter.split_documents(documents)
            return [Chunk.llamaindex2chunk(node) for node in nodes]
        else:
            return text_splitter.split_documents(documents)
@@ -106,7 +107,7 @@ class ChunkManager:

    def set_text_splitter(
        self,
-        text_splitter,
+        text_splitter: TextSplitter,
        splitter_type: SplitterType = SplitterType.LANGCHAIN,
    ) -> None:
        """Add text splitter."""
@@ -115,13 +116,13 @@ class ChunkManager:

    def get_text_splitter(
        self,
-    ) -> Any:
+    ) -> TextSplitter:
        """Return text splitter."""
        return self._select_text_splitter()

    def _select_text_splitter(
        self,
-    ):
+    ) -> TextSplitter:
        """Select text splitter by chunk strategy."""
        if self._text_splitter:
            return self._text_splitter
@@ -1,13 +1,13 @@
"""Embedding implementations."""

-import asyncio
-from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional

import aiohttp
import requests

from dbgpt._private.pydantic import BaseModel, Extra, Field
+from dbgpt.core import Embeddings

DEFAULT_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
DEFAULT_INSTRUCT_MODEL = "hkunlp/instructor-large"
@@ -22,34 +22,6 @@ DEFAULT_QUERY_BGE_INSTRUCTION_EN = (
DEFAULT_QUERY_BGE_INSTRUCTION_ZH = "为这个句子生成表示以用于检索相关文章:"


-class Embeddings(ABC):
-    """Interface for embedding models.
-
-    Refer to `Langchain Embeddings <https://github.com/langchain-ai/langchain/tree/
-    master/libs/langchain/langchain/embeddings>`_.
-    """
-
-    @abstractmethod
-    def embed_documents(self, texts: List[str]) -> List[List[float]]:
-        """Embed search docs."""
-
-    @abstractmethod
-    def embed_query(self, text: str) -> List[float]:
-        """Embed query text."""
-
-    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
-        """Asynchronous Embed search docs."""
-        return await asyncio.get_running_loop().run_in_executor(
-            None, self.embed_documents, texts
-        )
-
-    async def aembed_query(self, text: str) -> List[float]:
-        """Asynchronous Embed query text."""
-        return await asyncio.get_running_loop().run_in_executor(
-            None, self.embed_query, text
-        )


class HuggingFaceEmbeddings(BaseModel, Embeddings):
    """HuggingFace sentence_transformers embedding models.
13 dbgpt/rag/evaluation/__init__.py Normal file
@@ -0,0 +1,13 @@
"""Module for evaluation of RAG."""

from .retriever import (  # noqa: F401
    RetrieverEvaluationMetric,
    RetrieverEvaluator,
    RetrieverSimilarityMetric,
)

__ALL__ = [
    "RetrieverEvaluator",
    "RetrieverSimilarityMetric",
    "RetrieverEvaluationMetric",
]
171 dbgpt/rag/evaluation/retriever.py Normal file
@@ -0,0 +1,171 @@
+"""Evaluation for retriever."""
+from abc import ABC
+from typing import Any, Dict, List, Optional, Sequence, Type
+
+from dbgpt.core import Embeddings, LLMClient
+from dbgpt.core.interface.evaluation import (
+    BaseEvaluationResult,
+    DatasetType,
+    EvaluationMetric,
+    EvaluationResult,
+    Evaluator,
+)
+from dbgpt.core.interface.operators.retriever import RetrieverOperator
+from dbgpt.util.similarity_util import calculate_cosine_similarity
+
+from ..operators.evaluation import RetrieverEvaluatorOperator
+
+
+class RetrieverEvaluationMetric(EvaluationMetric[List[str], str], ABC):
+    """Evaluation metric for retriever.
+
+    The prediction is a list of str (content from chunks) and the context is
+    a string.
+    """
+
+
+class RetrieverSimilarityMetric(RetrieverEvaluationMetric):
+    """Similarity metric for retriever."""
+
+    def __init__(self, embeddings: Embeddings):
+        """Create a SimilarityMetric with embeddings."""
+        self._embeddings = embeddings
+
+    def sync_compute(
+        self,
+        prediction: List[str],
+        contexts: Optional[Sequence[str]] = None,
+    ) -> BaseEvaluationResult:
+        """Compute the evaluation metric.
+
+        Args:
+            prediction(List[str]): The retrieved chunks from the retriever.
+            contexts(Sequence[str]): The contexts from dataset.
+
+        Returns:
+            BaseEvaluationResult: The evaluation result.
+                The score is the mean of the cosine similarity between the
+                prediction and the contexts.
+        """
+        if not prediction or not contexts:
+            return BaseEvaluationResult(
+                prediction=prediction,
+                contexts=contexts,
+                score=0.0,
+            )
+        try:
+            import numpy as np
+        except ImportError:
+            raise ImportError("numpy is required for RetrieverSimilarityMetric")
+
+        similarity: np.ndarray = calculate_cosine_similarity(
+            self._embeddings, contexts[0], prediction
+        )
+        return BaseEvaluationResult(
+            prediction=prediction,
+            contexts=contexts,
+            score=float(similarity.mean()),
+        )
+
+
+class RetrieverEvaluator(Evaluator):
+    """Evaluator for relevancy.
+
+    Examples:
+        .. code-block:: python
+
+            import os
+            import asyncio
+            from dbgpt.rag.operators import (
+                EmbeddingRetrieverOperator,
+                RetrieverEvaluatorOperator,
+            )
+            from dbgpt.rag.evaluation import (
+                RetrieverEvaluator,
+                RetrieverSimilarityMetric,
+            )
+            from dbgpt.rag.embedding import DefaultEmbeddingFactory
+            from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
+            from dbgpt.storage.vector_store.connector import VectorStoreConnector
+            from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH
+
+            embeddings = DefaultEmbeddingFactory(
+                default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"),
+            ).create()
+            vector_connector = VectorStoreConnector.from_default(
+                "Chroma",
+                vector_store_config=ChromaVectorConfig(
+                    name="my_test_schema",
+                    persist_path=os.path.join(PILOT_PATH, "data"),
+                ),
+                embedding_fn=embeddings,
+            )
+
+            dataset = [
+                {
+                    "query": "what is awel talk about",
+                    "contexts": [
+                        "Through the AWEL API, you can focus on the development"
+                        " of business logic for LLMs applications without paying "
+                        "attention to cumbersome model and environment details."
+                    ],
+                },
+            ]
+            evaluator = RetrieverEvaluator(
+                operator_cls=EmbeddingRetrieverOperator,
+                embeddings=embeddings,
+                operator_kwargs={
+                    "top_k": 5,
+                    "vector_store_connector": vector_connector,
+                },
+            )
+            results = asyncio.run(evaluator.evaluate(dataset))
+    """
+
+    def __init__(
+        self,
+        operator_cls: Type[RetrieverOperator],
+        llm_client: Optional[LLMClient] = None,
+        embeddings: Optional[Embeddings] = None,
+        operator_kwargs: Optional[Dict] = None,
+    ):
+        """Create a new RetrieverEvaluator."""
+        if not operator_kwargs:
+            operator_kwargs = {}
+        self._operator_cls = operator_cls
+        self._operator_kwargs: Dict[str, Any] = operator_kwargs
+        self.embeddings = embeddings
+        super().__init__(llm_client=llm_client)
+
+    async def evaluate(
+        self,
+        dataset: DatasetType,
+        metrics: Optional[List[EvaluationMetric]] = None,
+        query_key: str = "query",
+        contexts_key: str = "contexts",
+        prediction_key: str = "prediction",
+        parallel_num: int = 1,
+        **kwargs,
+    ) -> List[List[EvaluationResult]]:
+        """Evaluate the dataset."""
+        from dbgpt.core.awel import DAG, IteratorTrigger, MapOperator
+
+        if not metrics:
+            if not self.embeddings:
+                raise ValueError("embeddings are required for SimilarityMetric")
+            metrics = [RetrieverSimilarityMetric(self.embeddings)]
+
+        with DAG("relevancy_evaluation_dag"):
+            input_task = IteratorTrigger(dataset)
+            query_task: MapOperator = MapOperator(lambda x: x[query_key])
+            retriever_task = self._operator_cls(**self._operator_kwargs)
+            retriever_eva_task = RetrieverEvaluatorOperator(
+                evaluation_metrics=metrics, llm_client=self.llm_client
+            )
+            input_task >> query_task
+            query_task >> retriever_eva_task
+            query_task >> retriever_task >> retriever_eva_task
+            input_task >> MapOperator(lambda x: x[contexts_key]) >> retriever_eva_task
+            input_task >> retriever_eva_task
+
+        results = await input_task.trigger(parallel_num=parallel_num)
+        return [item for _, item in results]
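
Editor's note: a minimal sketch of using the metric directly, outside the AWEL DAG that evaluate() builds. ToyEmbeddings is the hypothetical stand-in sketched earlier; compute() is awaited here the same way RetrieverEvaluatorOperator (added later in this diff) awaits it.

# Sketch (not from this commit): the score is the mean cosine similarity
# between contexts[0] and each retrieved text.
import asyncio

from dbgpt.rag.evaluation import RetrieverSimilarityMetric

metric = RetrieverSimilarityMetric(embeddings=ToyEmbeddings())
result = asyncio.run(
    metric.compute(
        ["AWEL lets you focus on business logic."],  # prediction: chunk contents
        ["Through the AWEL API, you can focus on business logic."],  # contexts
    )
)
print(result.score)
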
@@ -154,7 +154,7 @@ class Knowledge(ABC):
         self._type = knowledge_type
         self._data_loader = data_loader

-    def load(self):
+    def load(self) -> List[Document]:
         """Load knowledge from data_loader."""
         documents = self._load()
         return self._postprocess(documents)
@@ -174,7 +174,7 @@ class Knowledge(ABC):
         return docs

     @abstractmethod
-    def _load(self):
+    def _load(self) -> List[Document]:
         """Preprocess knowledge from data_loader."""

     @classmethod
@@ -3,6 +3,7 @@
 from .datasource import DatasourceRetrieverOperator  # noqa: F401
 from .db_schema import DBSchemaRetrieverOperator  # noqa: F401
 from .embedding import EmbeddingRetrieverOperator  # noqa: F401
+from .evaluation import RetrieverEvaluatorOperator  # noqa: F401
 from .knowledge import KnowledgeOperator  # noqa: F401
 from .rerank import RerankOperator  # noqa: F401
 from .rewrite import QueryRewriteOperator  # noqa: F401
@@ -16,4 +17,5 @@ __all__ = [
     "RerankOperator",
     "QueryRewriteOperator",
     "SummaryAssemblerOperator",
+    "RetrieverEvaluatorOperator",
 ]
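
Editor's note: with the export above in place, the new operator is importable from the package root alongside the existing RAG operators:

from dbgpt.rag.operators import EmbeddingRetrieverOperator, RetrieverEvaluatorOperator
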
@@ -1,16 +1,17 @@
 """Embedding retriever operator."""

 from functools import reduce
-from typing import Any, Optional
+from typing import List, Optional, Union

 from dbgpt.core.interface.operators.retriever import RetrieverOperator
 from dbgpt.rag.chunk import Chunk
 from dbgpt.rag.retriever.embedding import EmbeddingRetriever
 from dbgpt.rag.retriever.rerank import Ranker
 from dbgpt.rag.retriever.rewrite import QueryRewrite
 from dbgpt.storage.vector_store.connector import VectorStoreConnector


-class EmbeddingRetrieverOperator(RetrieverOperator[Any, Any]):
+class EmbeddingRetrieverOperator(RetrieverOperator[Union[str, List[str]], List[Chunk]]):
     """The Embedding Retriever Operator."""

     def __init__(
@@ -32,7 +33,7 @@ class EmbeddingRetrieverOperator(RetrieverOperator[Any, Any]):
             rerank=rerank,
         )

-    def retrieve(self, query: Any) -> Any:
+    def retrieve(self, query: Union[str, List[str]]) -> List[Chunk]:
         """Retrieve the candidates."""
         if isinstance(query, str):
             return self._retriever.retrieve_with_scores(query, self._score_threshold)
61 dbgpt/rag/operators/evaluation.py Normal file
@@ -0,0 +1,61 @@
+"""Evaluation operators."""
+import asyncio
+from typing import Any, List, Optional
+
+from dbgpt.core.awel import JoinOperator
+from dbgpt.core.interface.evaluation import EvaluationMetric, EvaluationResult
+from dbgpt.core.interface.llm import LLMClient
+
+from ..chunk import Chunk
+
+
+class RetrieverEvaluatorOperator(JoinOperator[List[EvaluationResult]]):
+    """Evaluator for retriever."""
+
+    def __init__(
+        self,
+        evaluation_metrics: List[EvaluationMetric],
+        llm_client: Optional[LLMClient] = None,
+        **kwargs,
+    ):
+        """Create a new RetrieverEvaluatorOperator."""
+        self.llm_client = llm_client
+        self.evaluation_metrics = evaluation_metrics
+        super().__init__(combine_function=self._do_evaluation, **kwargs)
+
+    async def _do_evaluation(
+        self,
+        query: str,
+        prediction: List[Chunk],
+        contexts: List[str],
+        raw_dataset: Any = None,
+    ) -> List[EvaluationResult]:
+        """Run evaluation.
+
+        Args:
+            query(str): The query string.
+            prediction(List[Chunk]): The retrieved chunks from the retriever.
+            contexts(List[str]): The contexts from dataset.
+            raw_dataset(Any): The raw data (single row) from dataset.
+        """
+        if isinstance(contexts, str):
+            contexts = [contexts]
+        prediction_strs = [chunk.content for chunk in prediction]
+        tasks = []
+        for metric in self.evaluation_metrics:
+            tasks.append(metric.compute(prediction_strs, contexts))
+        task_results = await asyncio.gather(*tasks)
+        results = []
+        for result, metric in zip(task_results, self.evaluation_metrics):
+            results.append(
+                EvaluationResult(
+                    query=query,
+                    prediction=prediction,
+                    score=result.score,
+                    contexts=contexts,
+                    passing=result.passing,
+                    raw_dataset=raw_dataset,
+                    metric_name=metric.name,
+                )
+            )
+        return results
@@ -256,6 +256,6 @@ class AgentDummyTrigger(Trigger):
         """Initialize a HttpTrigger."""
         super().__init__(**kwargs)

-    async def trigger(self) -> None:
+    async def trigger(self, **kwargs) -> None:
         """Trigger the DAG. Not used in HttpTrigger."""
-        pass
+        raise NotImplementedError("Dummy trigger does not support trigger.")
@@ -44,8 +44,10 @@ class BaseAssembler(ABC):
         with root_tracer.start_span("BaseAssembler.load_knowledge", metadata=metadata):
             self.load_knowledge(self._knowledge)

-    def load_knowledge(self, knowledge) -> None:
+    def load_knowledge(self, knowledge: Optional[Knowledge] = None) -> None:
         """Load knowledge Pipeline."""
+        if not knowledge:
+            raise ValueError("knowledge must be provided.")
         with root_tracer.start_span("BaseAssembler.knowledge.load"):
             documents = knowledge.load()
         with root_tracer.start_span("BaseAssembler.chunk_manager.split"):
@@ -56,8 +58,12 @@ class BaseAssembler(ABC):
         """Return a retriever."""

     @abstractmethod
-    def persist(self, chunks: List[Chunk]) -> None:
-        """Persist chunks."""
+    def persist(self) -> List[str]:
+        """Persist chunks.
+
+        Returns:
+            List[str]: List of persisted chunk ids.
+        """

     def get_chunks(self) -> List[Chunk]:
         """Return chunks."""
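
Editor's note: a short sketch of the revised persist() contract. Assemblers now return the ids of the chunks they wrote instead of taking chunks as an argument; the construction below mirrors the example file added later in this commit, and vector_connector stands for a configured VectorStoreConnector.

# Sketch (not from this commit): persist() now returns the persisted chunk ids.
import os

from dbgpt.configs.model_config import ROOT_PATH
from dbgpt.rag.chunk_manager import ChunkParameters
from dbgpt.rag.knowledge import KnowledgeFactory
from dbgpt.serve.rag.assembler.embedding import EmbeddingAssembler

knowledge = KnowledgeFactory.from_file_path(os.path.join(ROOT_PATH, "docs/docs/awel/awel.md"))
assembler = EmbeddingAssembler.load_from_knowledge(
    knowledge=knowledge,  # now required, no longer defaults to None
    chunk_parameters=ChunkParameters(chunk_strategy="CHUNK_BY_SIZE"),
    vector_store_connector=vector_connector,  # a configured VectorStoreConnector
)
chunk_ids = assembler.persist()  # List[str]: ids of the persisted chunks
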
@@ -129,7 +129,11 @@ class DBSchemaAssembler(BaseAssembler):
         return self._chunks

     def persist(self) -> List[str]:
-        """Persist chunks into vector store."""
+        """Persist chunks into vector store.
+
+        Returns:
+            List[str]: List of chunk ids.
+        """
         return self._vector_store_connector.load_document(self._chunks)

     def _extract_info(self, chunks) -> List[Chunk]:
@@ -29,7 +29,7 @@ class EmbeddingAssembler(BaseAssembler):

     def __init__(
         self,
-        knowledge: Knowledge = None,
+        knowledge: Knowledge,
         chunk_parameters: Optional[ChunkParameters] = None,
         embedding_model: Optional[str] = None,
         embedding_factory: Optional[EmbeddingFactory] = None,
@@ -69,7 +69,7 @@ class EmbeddingAssembler(BaseAssembler):
     @classmethod
     def load_from_knowledge(
         cls,
-        knowledge: Knowledge = None,
+        knowledge: Knowledge,
         chunk_parameters: Optional[ChunkParameters] = None,
         embedding_model: Optional[str] = None,
         embedding_factory: Optional[EmbeddingFactory] = None,
@@ -99,7 +99,11 @@ class EmbeddingAssembler(BaseAssembler):
         )

     def persist(self) -> List[str]:
-        """Persist chunks into vector store."""
+        """Persist chunks into vector store.
+
+        Returns:
+            List[str]: List of chunk ids.
+        """
         return self._vector_store_connector.load_document(self._chunks)

     def _extract_info(self, chunks) -> List[Chunk]:
@@ -32,7 +32,7 @@ class SummaryAssembler(BaseAssembler):

     def __init__(
         self,
-        knowledge: Knowledge = None,
+        knowledge: Knowledge,
         chunk_parameters: Optional[ChunkParameters] = None,
         model_name: Optional[str] = None,
         llm_client: Optional[LLMClient] = None,
@@ -69,7 +69,7 @@ class SummaryAssembler(BaseAssembler):
     @classmethod
     def load_from_knowledge(
         cls,
-        knowledge: Knowledge = None,
+        knowledge: Knowledge,
         chunk_parameters: Optional[ChunkParameters] = None,
         model_name: Optional[str] = None,
         llm_client: Optional[LLMClient] = None,
@@ -104,6 +104,7 @@ class SummaryAssembler(BaseAssembler):

     def persist(self) -> List[str]:
         """Persist chunks into store."""
+        raise NotImplementedError

     def _extract_info(self, chunks) -> List[Chunk]:
         """Extract info from chunks."""
37 dbgpt/util/similarity_util.py Normal file
@@ -0,0 +1,37 @@
+"""Utility functions for calculating similarity."""
+from typing import TYPE_CHECKING, Any, Sequence
+
+if TYPE_CHECKING:
+    from dbgpt.core.interface.embeddings import Embeddings
+
+
+def calculate_cosine_similarity(
+    embeddings: "Embeddings", prediction: str, contexts: Sequence[str]
+) -> Any:
+    """Calculate the cosine similarity between a prediction and a list of contexts.
+
+    Args:
+        embeddings(Embeddings): The embeddings to use.
+        prediction(str): The prediction.
+        contexts(Sequence[str]): The contexts.
+
+    Returns:
+        numpy.ndarray: The cosine similarity.
+    """
+    try:
+        import numpy as np
+    except ImportError:
+        raise ImportError("numpy is required for SimilarityMetric")
+    prediction_vec = np.asarray(embeddings.embed_query(prediction)).reshape(1, -1)
+    context_list = list(contexts)
+    context_list_vec = np.asarray(embeddings.embed_documents(context_list)).reshape(
+        len(contexts), -1
+    )
+    # cos(a,b) = dot(a,b) / (norm(a) * norm(b))
+    dot = np.dot(context_list_vec, prediction_vec.T).reshape(
+        -1,
+    )
+    norm = np.linalg.norm(context_list_vec, axis=1) * np.linalg.norm(
+        prediction_vec, axis=1
+    )
+    return dot / norm
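
Editor's note: a worked example of the helper above. Because the Embeddings import is guarded by TYPE_CHECKING, the embeddings argument is only duck-typed, so a plain stub suffices; FixedEmbeddings is hypothetical and returns hand-picked vectors so the output is easy to verify.

# Worked example (not from this commit): cos([1,0],[1,0]) = 1.0 and
# cos([1,0],[0,1]) = 0.0, so the returned array is [1.0, 0.0].
from typing import List

from dbgpt.util.similarity_util import calculate_cosine_similarity


class FixedEmbeddings:
    """Duck-typed stand-in for an Embeddings implementation."""

    def embed_query(self, text: str) -> List[float]:
        return [1.0, 0.0]  # the "prediction" vector

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [[1.0, 0.0], [0.0, 1.0]]  # parallel and orthogonal contexts


similarity = calculate_cosine_similarity(FixedEmbeddings(), "q", ["ctx_a", "ctx_b"])
print(similarity)  # [1. 0.]
print(float(similarity.mean()))  # 0.5 -- the score RetrieverSimilarityMetric reports
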
82 examples/rag/retriever_evaluation_example.py Normal file
@@ -0,0 +1,82 @@
+import asyncio
+import os
+from typing import Optional
+
+from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH, ROOT_PATH
+from dbgpt.core import Embeddings
+from dbgpt.rag.chunk_manager import ChunkParameters
+from dbgpt.rag.embedding import DefaultEmbeddingFactory
+from dbgpt.rag.evaluation import RetrieverEvaluator
+from dbgpt.rag.knowledge import KnowledgeFactory
+from dbgpt.rag.operators import EmbeddingRetrieverOperator
+from dbgpt.serve.rag.assembler.embedding import EmbeddingAssembler
+from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
+from dbgpt.storage.vector_store.connector import VectorStoreConnector
+
+
+def _create_embeddings(
+    model_name: Optional[str] = "text2vec-large-chinese",
+) -> Embeddings:
+    """Create embeddings."""
+    return DefaultEmbeddingFactory(
+        default_model_name=os.path.join(MODEL_PATH, model_name),
+    ).create()
+
+
+def _create_vector_connector(
+    embeddings: Embeddings, space_name: str = "retriever_evaluation_example"
+) -> VectorStoreConnector:
+    """Create vector connector."""
+    return VectorStoreConnector.from_default(
+        "Chroma",
+        vector_store_config=ChromaVectorConfig(
+            name=space_name,
+            persist_path=os.path.join(PILOT_PATH, "data"),
+        ),
+        embedding_fn=embeddings,
+    )
+
+
+async def main():
+    file_path = os.path.join(ROOT_PATH, "docs/docs/awel/awel.md")
+    knowledge = KnowledgeFactory.from_file_path(file_path)
+    embeddings = _create_embeddings()
+    vector_connector = _create_vector_connector(embeddings)
+    chunk_parameters = ChunkParameters(chunk_strategy="CHUNK_BY_SIZE")
+    # Get the embedding assembler.
+    assembler = EmbeddingAssembler.load_from_knowledge(
+        knowledge=knowledge,
+        chunk_parameters=chunk_parameters,
+        vector_store_connector=vector_connector,
+    )
+    assembler.persist()
+
+    dataset = [
+        {
+            "query": "what is awel talk about",
+            "contexts": [
+                "Through the AWEL API, you can focus on the development"
+                " of business logic for LLMs applications without paying "
+                "attention to cumbersome model and environment details."
+            ],
+        },
+    ]
+    evaluator = RetrieverEvaluator(
+        operator_cls=EmbeddingRetrieverOperator,
+        embeddings=embeddings,
+        operator_kwargs={
+            "top_k": 5,
+            "vector_store_connector": vector_connector,
+        },
+    )
+    results = await evaluator.evaluate(dataset)
+    for result in results:
+        for metric in result:
+            print("Metric:", metric.metric_name)
+            print("Question:", metric.query)
+            print("Score:", metric.score)
+    print(f"Results:\n{results}")
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
@@ -7,7 +7,7 @@
 curl --location --request POST 'http://127.0.0.1:5555/api/v1/awel/trigger/examples/rag/embedding' \
 --header 'Content-Type: application/json' \
 --data-raw '{
-    "url": "https://docs.dbgpt.site/docs/awel"
+    "url": "https://docs.dbgpt.site/docs/latest/awel/"
 }'
 """
@@ -57,6 +57,8 @@ clean_local_data() {
     rm -rf /root/DB-GPT/pilot/message
     rm -f /root/DB-GPT/logs/*
+    rm -f /root/DB-GPT/logsDbChatOutputParser.log
+    rm -rf /root/DB-GPT/pilot/meta_data/alembic/versions/*
     rm -rf /root/DB-GPT/pilot/meta_data/*.db
 }

 usage() {