feat(rag): Support rag retriever evaluation (#1291)

Fangyin Cheng 2024-03-14 13:06:57 +08:00 committed by GitHub
parent cd2dcc253c
commit adaa68eb00
34 changed files with 1452 additions and 67 deletions

View File

@ -278,6 +278,7 @@ CREATE TABLE `dbgpt_serve_flow` (
`flow_category` varchar(64) DEFAULT NULL COMMENT 'Flow category',
`description` varchar(512) DEFAULT NULL COMMENT 'Flow description',
`state` varchar(32) DEFAULT NULL COMMENT 'Flow state',
`error_message` varchar(512) NULL comment 'Error message',
`source` varchar(64) DEFAULT NULL COMMENT 'Flow source',
`source_url` varchar(512) DEFAULT NULL COMMENT 'Flow source url',
`version` varchar(32) DEFAULT NULL COMMENT 'Flow version',

View File

@ -0,0 +1,395 @@
-- Full SQL of v0.5.1, please do not modify this file (it must be the same as the file in the release package)
CREATE
DATABASE IF NOT EXISTS dbgpt;
use dbgpt;
-- For alembic migration tool
CREATE TABLE IF NOT EXISTS `alembic_version`
(
version_num VARCHAR(32) NOT NULL,
CONSTRAINT alembic_version_pkc PRIMARY KEY (version_num)
) DEFAULT CHARSET=utf8mb4 ;
CREATE TABLE IF NOT EXISTS `knowledge_space`
(
`id` int NOT NULL AUTO_INCREMENT COMMENT 'auto increment id',
`name` varchar(100) NOT NULL COMMENT 'knowledge space name',
`vector_type` varchar(50) NOT NULL COMMENT 'vector type',
`desc` varchar(500) NOT NULL COMMENT 'description',
`owner` varchar(100) DEFAULT NULL COMMENT 'owner',
`context` TEXT DEFAULT NULL COMMENT 'context argument',
`gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
`gmt_modified` TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time',
PRIMARY KEY (`id`),
KEY `idx_name` (`name`) COMMENT 'index:idx_name'
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT='knowledge space table';
CREATE TABLE IF NOT EXISTS `knowledge_document`
(
`id` int NOT NULL AUTO_INCREMENT COMMENT 'auto increment id',
`doc_name` varchar(100) NOT NULL COMMENT 'document path name',
`doc_type` varchar(50) NOT NULL COMMENT 'doc type',
`space` varchar(50) NOT NULL COMMENT 'knowledge space',
`chunk_size` int NOT NULL COMMENT 'chunk size',
`last_sync` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'last sync time',
`status` varchar(50) NOT NULL COMMENT 'status TODO,RUNNING,FAILED,FINISHED',
`content` LONGTEXT NOT NULL COMMENT 'knowledge embedding sync result',
`result` TEXT NULL COMMENT 'knowledge content',
`vector_ids` LONGTEXT NULL COMMENT 'vector_ids',
`summary` LONGTEXT NULL COMMENT 'knowledge summary',
`gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
`gmt_modified` TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time',
PRIMARY KEY (`id`),
KEY `idx_doc_name` (`doc_name`) COMMENT 'index:idx_doc_name'
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT='knowledge document table';
CREATE TABLE IF NOT EXISTS `document_chunk`
(
`id` int NOT NULL AUTO_INCREMENT COMMENT 'auto increment id',
`doc_name` varchar(100) NOT NULL COMMENT 'document path name',
`doc_type` varchar(50) NOT NULL COMMENT 'doc type',
`document_id` int NOT NULL COMMENT 'document parent id',
`content` longtext NOT NULL COMMENT 'chunk content',
`meta_info` varchar(200) NOT NULL COMMENT 'metadata info',
`gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
`gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time',
PRIMARY KEY (`id`),
KEY `idx_document_id` (`document_id`) COMMENT 'index:document_id'
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT='knowledge document chunk detail';
CREATE TABLE IF NOT EXISTS `connect_config`
(
`id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id',
`db_type` varchar(255) NOT NULL COMMENT 'db type',
`db_name` varchar(255) NOT NULL COMMENT 'db name',
`db_path` varchar(255) DEFAULT NULL COMMENT 'file db path',
`db_host` varchar(255) DEFAULT NULL COMMENT 'db connect host (not file db)',
`db_port` varchar(255) DEFAULT NULL COMMENT 'db connect port (not file db)',
`db_user` varchar(255) DEFAULT NULL COMMENT 'db user',
`db_pwd` varchar(255) DEFAULT NULL COMMENT 'db password',
`comment` text COMMENT 'db comment',
`sys_code` varchar(128) DEFAULT NULL COMMENT 'System code',
PRIMARY KEY (`id`),
UNIQUE KEY `uk_db` (`db_name`),
KEY `idx_q_db_type` (`db_type`)
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT 'Connection config';
CREATE TABLE IF NOT EXISTS `chat_history`
(
`id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id',
`conv_uid` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'Conversation record unique id',
`chat_mode` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'Conversation scene mode',
`summary` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'Conversation record summary',
`user_name` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'interlocutor',
`messages` text COLLATE utf8mb4_unicode_ci COMMENT 'Conversation details',
`message_ids` text COLLATE utf8mb4_unicode_ci COMMENT 'Message id list, split by comma',
`sys_code` varchar(128) DEFAULT NULL COMMENT 'System code',
`gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
`gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time',
UNIQUE KEY `conv_uid` (`conv_uid`),
PRIMARY KEY (`id`)
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT 'Chat history';
CREATE TABLE IF NOT EXISTS `chat_history_message`
(
`id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id',
`conv_uid` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'Conversation record unique id',
`index` int NOT NULL COMMENT 'Message index',
`round_index` int NOT NULL COMMENT 'Round of conversation',
`message_detail` text COLLATE utf8mb4_unicode_ci COMMENT 'Message details, json format',
`gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
`gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time',
UNIQUE KEY `message_uid_index` (`conv_uid`, `index`),
PRIMARY KEY (`id`)
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT 'Chat history message';
CREATE TABLE IF NOT EXISTS `chat_feed_back`
(
`id` bigint(20) NOT NULL AUTO_INCREMENT,
`conv_uid` varchar(128) DEFAULT NULL COMMENT 'Conversation ID',
`conv_index` int(4) DEFAULT NULL COMMENT 'Round of conversation',
`score` int(1) DEFAULT NULL COMMENT 'Score of user',
`ques_type` varchar(32) DEFAULT NULL COMMENT 'User question category',
`question` longtext DEFAULT NULL COMMENT 'User question',
`knowledge_space` varchar(128) DEFAULT NULL COMMENT 'Knowledge space name',
`messages` longtext DEFAULT NULL COMMENT 'The details of user feedback',
`user_name` varchar(128) DEFAULT NULL COMMENT 'User name',
`gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
`gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time',
PRIMARY KEY (`id`),
UNIQUE KEY `uk_conv` (`conv_uid`,`conv_index`),
KEY `idx_conv` (`conv_uid`,`conv_index`)
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT='User feedback table';
CREATE TABLE IF NOT EXISTS `my_plugin`
(
`id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id',
`tenant` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'user tenant',
`user_code` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'user code',
`user_name` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'user name',
`name` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'plugin name',
`file_name` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'plugin package file name',
`type` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin type',
`version` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin version',
`use_count` int DEFAULT NULL COMMENT 'plugin total use count',
`succ_count` int DEFAULT NULL COMMENT 'plugin total success count',
`sys_code` varchar(128) DEFAULT NULL COMMENT 'System code',
`gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'plugin install time',
PRIMARY KEY (`id`),
UNIQUE KEY `name` (`name`)
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='User plugin table';
CREATE TABLE IF NOT EXISTS `plugin_hub`
(
`id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id',
`name` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'plugin name',
`description` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'plugin description',
`author` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin author',
`email` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin author email',
`type` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin type',
`version` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin version',
`storage_channel` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin storage channel',
`storage_url` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin download url',
`download_param` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin download param',
`gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'plugin upload time',
`installed` int DEFAULT NULL COMMENT 'plugin already installed count',
PRIMARY KEY (`id`),
UNIQUE KEY `name` (`name`)
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='Plugin Hub table';
CREATE TABLE IF NOT EXISTS `prompt_manage`
(
`id` int(11) NOT NULL AUTO_INCREMENT,
`chat_scene` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Chat scene',
`sub_chat_scene` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Sub chat scene',
`prompt_type` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt type: common or private',
`prompt_name` varchar(256) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'prompt name',
`content` longtext COLLATE utf8mb4_unicode_ci COMMENT 'Prompt content',
`input_variables` varchar(1024) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt input variables(split by comma))',
`model` varchar(128) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt model name(we can use different models for different prompt)',
`prompt_language` varchar(32) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt language(eg:en, zh-cn)',
`prompt_format` varchar(32) COLLATE utf8mb4_unicode_ci DEFAULT 'f-string' COMMENT 'Prompt format(eg: f-string, jinja2)',
`prompt_desc` varchar(512) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt description',
`user_name` varchar(128) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'User name',
`sys_code` varchar(128) DEFAULT NULL COMMENT 'System code',
`gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
`gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time',
PRIMARY KEY (`id`),
UNIQUE KEY `prompt_name_uiq` (`prompt_name`, `sys_code`, `prompt_language`, `model`),
KEY `gmt_created_idx` (`gmt_created`)
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='Prompt management table';
CREATE TABLE IF NOT EXISTS `gpts_conversations` (
`id` int(11) NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id',
`conv_id` varchar(255) NOT NULL COMMENT 'The unique id of the conversation record',
`user_goal` text NOT NULL COMMENT 'User''s goals content',
`gpts_name` varchar(255) NOT NULL COMMENT 'The gpts name',
`state` varchar(255) DEFAULT NULL COMMENT 'The gpts state',
`max_auto_reply_round` int(11) NOT NULL COMMENT 'max auto reply round',
`auto_reply_count` int(11) NOT NULL COMMENT 'auto reply count',
`user_code` varchar(255) DEFAULT NULL COMMENT 'user code',
`sys_code` varchar(255) DEFAULT NULL COMMENT 'system app code',
`created_at` datetime DEFAULT NULL COMMENT 'create time',
`updated_at` datetime DEFAULT NULL COMMENT 'last update time',
`team_mode` varchar(255) NULL COMMENT 'agent team work mode',
PRIMARY KEY (`id`),
UNIQUE KEY `uk_gpts_conversations` (`conv_id`),
KEY `idx_gpts_name` (`gpts_name`)
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT="gpt conversations";
CREATE TABLE IF NOT EXISTS `gpts_instance` (
`id` int(11) NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id',
`gpts_name` varchar(255) NOT NULL COMMENT 'Current AI assistant name',
`gpts_describe` varchar(2255) NOT NULL COMMENT 'Current AI assistant description',
`resource_db` text COMMENT 'List of structured database names contained in the current gpts',
`resource_internet` text COMMENT 'Is it possible to retrieve information from the internet',
`resource_knowledge` text COMMENT 'List of unstructured database names contained in the current gpts',
`gpts_agents` varchar(1000) DEFAULT NULL COMMENT 'List of agents names contained in the current gpts',
`gpts_models` varchar(1000) DEFAULT NULL COMMENT 'List of llm model names contained in the current gpts',
`language` varchar(100) DEFAULT NULL COMMENT 'gpts language',
`user_code` varchar(255) NOT NULL COMMENT 'user code',
`sys_code` varchar(255) DEFAULT NULL COMMENT 'system app code',
`created_at` datetime DEFAULT NULL COMMENT 'create time',
`updated_at` datetime DEFAULT NULL COMMENT 'last update time',
`team_mode` varchar(255) NOT NULL COMMENT 'Team work mode',
`is_sustainable` tinyint(1) NOT NULL COMMENT 'Applications for sustainable dialogue',
PRIMARY KEY (`id`),
UNIQUE KEY `uk_gpts` (`gpts_name`)
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT="gpts instance";
CREATE TABLE `gpts_messages` (
`id` int(11) NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id',
`conv_id` varchar(255) NOT NULL COMMENT 'The unique id of the conversation record',
`sender` varchar(255) NOT NULL COMMENT 'Who is speaking in the current conversation turn',
`receiver` varchar(255) NOT NULL COMMENT 'Who receives the message in the current conversation turn',
`model_name` varchar(255) DEFAULT NULL COMMENT 'message generate model',
`rounds` int(11) NOT NULL COMMENT 'dialogue turns',
`content` text COMMENT 'Content of the speech',
`current_goal` text COMMENT 'The target corresponding to the current message',
`context` text COMMENT 'Current conversation context',
`review_info` text COMMENT 'Current conversation review info',
`action_report` text COMMENT 'Current conversation action report',
`role` varchar(255) DEFAULT NULL COMMENT 'The role of the current message content',
`created_at` datetime DEFAULT NULL COMMENT 'create time',
`updated_at` datetime DEFAULT NULL COMMENT 'last update time',
PRIMARY KEY (`id`),
KEY `idx_q_messages` (`conv_id`,`rounds`,`sender`)
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT="gpts message";
CREATE TABLE `gpts_plans` (
`id` int(11) NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id',
`conv_id` varchar(255) NOT NULL COMMENT 'The unique id of the conversation record',
`sub_task_num` int(11) NOT NULL COMMENT 'Subtask number',
`sub_task_title` varchar(255) NOT NULL COMMENT 'subtask title',
`sub_task_content` text NOT NULL COMMENT 'subtask content',
`sub_task_agent` varchar(255) DEFAULT NULL COMMENT 'Available agents corresponding to subtasks',
`resource_name` varchar(255) DEFAULT NULL COMMENT 'resource name',
`rely` varchar(255) DEFAULT NULL COMMENT 'Subtask dependencies, like: 1,2,3',
`agent_model` varchar(255) DEFAULT NULL COMMENT 'LLM model used by subtask processing agents',
`retry_times` int(11) DEFAULT NULL COMMENT 'number of retries',
`max_retry_times` int(11) DEFAULT NULL COMMENT 'Maximum number of retries',
`state` varchar(255) DEFAULT NULL COMMENT 'subtask status',
`result` longtext COMMENT 'subtask result',
`created_at` datetime DEFAULT NULL COMMENT 'create time',
`updated_at` datetime DEFAULT NULL COMMENT 'last update time',
PRIMARY KEY (`id`),
UNIQUE KEY `uk_sub_task` (`conv_id`,`sub_task_num`)
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT="gpt plan";
-- dbgpt.dbgpt_serve_flow definition
CREATE TABLE `dbgpt_serve_flow` (
`id` int NOT NULL AUTO_INCREMENT COMMENT 'Auto increment id',
`uid` varchar(128) NOT NULL COMMENT 'Unique id',
`dag_id` varchar(128) DEFAULT NULL COMMENT 'DAG id',
`name` varchar(128) DEFAULT NULL COMMENT 'Flow name',
`flow_data` text COMMENT 'Flow data, JSON format',
`user_name` varchar(128) DEFAULT NULL COMMENT 'User name',
`sys_code` varchar(128) DEFAULT NULL COMMENT 'System code',
`gmt_created` datetime DEFAULT NULL COMMENT 'Record creation time',
`gmt_modified` datetime DEFAULT NULL COMMENT 'Record update time',
`flow_category` varchar(64) DEFAULT NULL COMMENT 'Flow category',
`description` varchar(512) DEFAULT NULL COMMENT 'Flow description',
`state` varchar(32) DEFAULT NULL COMMENT 'Flow state',
`error_message` varchar(512) NULL comment 'Error message',
`source` varchar(64) DEFAULT NULL COMMENT 'Flow source',
`source_url` varchar(512) DEFAULT NULL COMMENT 'Flow source url',
`version` varchar(32) DEFAULT NULL COMMENT 'Flow version',
`label` varchar(128) DEFAULT NULL COMMENT 'Flow label',
`editable` int DEFAULT NULL COMMENT 'Editable, 0: editable, 1: not editable',
PRIMARY KEY (`id`),
UNIQUE KEY `uk_uid` (`uid`),
KEY `ix_dbgpt_serve_flow_sys_code` (`sys_code`),
KEY `ix_dbgpt_serve_flow_uid` (`uid`),
KEY `ix_dbgpt_serve_flow_dag_id` (`dag_id`),
KEY `ix_dbgpt_serve_flow_user_name` (`user_name`),
KEY `ix_dbgpt_serve_flow_name` (`name`)
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;
-- dbgpt.gpts_app definition
CREATE TABLE `gpts_app` (
`id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id',
`app_code` varchar(255) NOT NULL COMMENT 'Current AI assistant code',
`app_name` varchar(255) NOT NULL COMMENT 'Current AI assistant name',
`app_describe` varchar(2255) NOT NULL COMMENT 'Current AI assistant description',
`language` varchar(100) NOT NULL COMMENT 'gpts language',
`team_mode` varchar(255) NOT NULL COMMENT 'Team work mode',
`team_context` text COMMENT 'The execution logic and team member content that teams with different working modes rely on',
`user_code` varchar(255) DEFAULT NULL COMMENT 'user code',
`sys_code` varchar(255) DEFAULT NULL COMMENT 'system app code',
`created_at` datetime DEFAULT NULL COMMENT 'create time',
`updated_at` datetime DEFAULT NULL COMMENT 'last update time',
`icon` varchar(1024) DEFAULT NULL COMMENT 'app icon, url',
PRIMARY KEY (`id`),
UNIQUE KEY `uk_gpts_app` (`app_name`)
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;
CREATE TABLE `gpts_app_collection` (
`id` int(11) NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id',
`app_code` varchar(255) NOT NULL COMMENT 'Current AI assistant code',
`user_code` int(11) NOT NULL COMMENT 'user code',
`sys_code` varchar(255) NOT NULL COMMENT 'system app code',
`created_at` datetime DEFAULT NULL COMMENT 'create time',
`updated_at` datetime DEFAULT NULL COMMENT 'last update time',
PRIMARY KEY (`id`),
KEY `idx_app_code` (`app_code`),
KEY `idx_user_code` (`user_code`)
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT="gpt collections";
-- dbgpt.gpts_app_detail definition
CREATE TABLE `gpts_app_detail` (
`id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id',
`app_code` varchar(255) NOT NULL COMMENT 'Current AI assistant code',
`app_name` varchar(255) NOT NULL COMMENT 'Current AI assistant name',
`agent_name` varchar(255) NOT NULL COMMENT 'Agent name',
`node_id` varchar(255) NOT NULL COMMENT 'Current AI assistant Agent Node id',
`resources` text COMMENT 'Agent bind resource',
`prompt_template` text COMMENT 'Agent bind template',
`llm_strategy` varchar(25) DEFAULT NULL COMMENT 'Agent use llm strategy',
`llm_strategy_value` text COMMENT 'Agent use llm strategy value',
`created_at` datetime DEFAULT NULL COMMENT 'create time',
`updated_at` datetime DEFAULT NULL COMMENT 'last update time',
PRIMARY KEY (`id`),
UNIQUE KEY `uk_gpts_app_agent_node` (`app_name`,`agent_name`,`node_id`)
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;
CREATE
DATABASE IF NOT EXISTS EXAMPLE_1;
use EXAMPLE_1;
CREATE TABLE IF NOT EXISTS `users`
(
`id` int NOT NULL AUTO_INCREMENT,
`username` varchar(50) NOT NULL COMMENT 'username',
`password` varchar(50) NOT NULL COMMENT 'password',
`email` varchar(50) NOT NULL COMMENT 'email',
`phone` varchar(20) DEFAULT NULL COMMENT 'phone',
PRIMARY KEY (`id`),
KEY `idx_username` (`username`) COMMENT 'index: query by username'
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT='chat users table';
INSERT INTO users (username, password, email, phone)
VALUES ('user_1', 'password_1', 'user_1@example.com', '12345678901');
INSERT INTO users (username, password, email, phone)
VALUES ('user_2', 'password_2', 'user_2@example.com', '12345678902');
INSERT INTO users (username, password, email, phone)
VALUES ('user_3', 'password_3', 'user_3@example.com', '12345678903');
INSERT INTO users (username, password, email, phone)
VALUES ('user_4', 'password_4', 'user_4@example.com', '12345678904');
INSERT INTO users (username, password, email, phone)
VALUES ('user_5', 'password_5', 'user_5@example.com', '12345678905');
INSERT INTO users (username, password, email, phone)
VALUES ('user_6', 'password_6', 'user_6@example.com', '12345678906');
INSERT INTO users (username, password, email, phone)
VALUES ('user_7', 'password_7', 'user_7@example.com', '12345678907');
INSERT INTO users (username, password, email, phone)
VALUES ('user_8', 'password_8', 'user_8@example.com', '12345678908');
INSERT INTO users (username, password, email, phone)
VALUES ('user_9', 'password_9', 'user_9@example.com', '12345678909');
INSERT INTO users (username, password, email, phone)
VALUES ('user_10', 'password_10', 'user_10@example.com', '12345678900');
INSERT INTO users (username, password, email, phone)
VALUES ('user_11', 'password_11', 'user_11@example.com', '12345678901');
INSERT INTO users (username, password, email, phone)
VALUES ('user_12', 'password_12', 'user_12@example.com', '12345678902');
INSERT INTO users (username, password, email, phone)
VALUES ('user_13', 'password_13', 'user_13@example.com', '12345678903');
INSERT INTO users (username, password, email, phone)
VALUES ('user_14', 'password_14', 'user_14@example.com', '12345678904');
INSERT INTO users (username, password, email, phone)
VALUES ('user_15', 'password_15', 'user_15@example.com', '12345678905');
INSERT INTO users (username, password, email, phone)
VALUES ('user_16', 'password_16', 'user_16@example.com', '12345678906');
INSERT INTO users (username, password, email, phone)
VALUES ('user_17', 'password_17', 'user_17@example.com', '12345678907');
INSERT INTO users (username, password, email, phone)
VALUES ('user_18', 'password_18', 'user_18@example.com', '12345678908');
INSERT INTO users (username, password, email, phone)
VALUES ('user_19', 'password_19', 'user_19@example.com', '12345678909');
INSERT INTO users (username, password, email, phone)
VALUES ('user_20', 'password_20', 'user_20@example.com', '12345678900');

View File

@ -167,6 +167,8 @@ EMBEDDING_MODEL_CONFIG = {
# https://huggingface.co/BAAI/bge-large-zh
"bge-large-zh": os.path.join(MODEL_PATH, "bge-large-zh"),
"bge-base-zh": os.path.join(MODEL_PATH, "bge-base-zh"),
# https://huggingface.co/BAAI/bge-m3, bge needs normalize_embeddings=True
"bge-m3": os.path.join(MODEL_PATH, "bge-m3"),
"gte-large-zh": os.path.join(MODEL_PATH, "gte-large-zh"),
"gte-base-zh": os.path.join(MODEL_PATH, "gte-base-zh"),
"sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2"),

View File

@ -7,6 +7,7 @@ from dbgpt.core.interface.cache import ( # noqa: F401
CachePolicy,
CacheValue,
)
from dbgpt.core.interface.embeddings import Embeddings # noqa: F401
from dbgpt.core.interface.llm import ( # noqa: F401
DefaultMessageConverter,
LLMClient,
@ -103,4 +104,5 @@ __ALL__ = [
"DefaultStorageItemAdapter",
"QuerySpec",
"StorageError",
"Embeddings",
]

View File

@ -55,6 +55,7 @@ from .trigger.http_trigger import (
CommonLLMHttpResponseBody,
HttpTrigger,
)
from .trigger.iterator_trigger import IteratorTrigger
_request_http_trigger_available = False
try:
@ -100,6 +101,7 @@ __all__ = [
"TransformStreamAbsOperator",
"Trigger",
"HttpTrigger",
"IteratorTrigger",
"CommonLLMHTTPRequestContext",
"CommonLLMHttpResponseBody",
"CommonLLMHttpRequestBody",

View File

@ -277,7 +277,7 @@ class InputOperator(BaseOperator, Generic[OUT]):
return task_output
class TriggerOperator(InputOperator, Generic[OUT]):
class TriggerOperator(InputOperator[OUT], Generic[OUT]):
"""Operator node that triggers the DAG to run."""
def __init__(self, **kwargs) -> None:

View File

@ -60,8 +60,8 @@ class DefaultWorkflowRunner(WorkflowRunner):
streaming_call=streaming_call,
node_name_to_ids=job_manager._node_name_to_ids,
)
if node.dag:
self._running_dag_ctx[node.dag.dag_id] = dag_ctx
# if node.dag:
# self._running_dag_ctx[node.dag.dag_id] = dag_ctx
logger.info(
f"Begin run workflow from end operator, id: {node.node_id}, runner: {self}"
)
@ -76,8 +76,8 @@ class DefaultWorkflowRunner(WorkflowRunner):
if not streaming_call and node.dag:
# streaming call not work for dag end
await node.dag._after_dag_end()
if node.dag:
del self._running_dag_ctx[node.dag.dag_id]
# if node.dag:
# del self._running_dag_ctx[node.dag.dag_id]
return dag_ctx
async def _execute_node(

View File

@ -3,11 +3,13 @@ from abc import ABC, abstractmethod
from enum import Enum
from typing import (
Any,
AsyncIterable,
AsyncIterator,
Awaitable,
Callable,
Dict,
Generic,
Iterable,
List,
Optional,
TypeVar,
@ -421,3 +423,40 @@ class InputSource(ABC, Generic[T]):
Returns:
TaskOutput[T]: The output object read from current source
"""
@classmethod
def from_data(cls, data: T) -> "InputSource[T]":
"""Create an InputSource from data.
Args:
data (T): The data to create the InputSource from.
Returns:
InputSource[T]: The InputSource created from the data.
"""
from .task_impl import SimpleInputSource
return SimpleInputSource(data, streaming=False)
@classmethod
def from_iterable(
cls, iterable: Union[AsyncIterable[T], Iterable[T]]
) -> "InputSource[T]":
"""Create an InputSource from an iterable.
Args:
iterable (Union[AsyncIterable[T], Iterable[T]]): The iterable to create the InputSource from.
Returns:
InputSource[T]: The InputSource created from the iterable.
"""
from .task_impl import SimpleInputSource
return SimpleInputSource(iterable, streaming=True)
@classmethod
def from_callable(cls) -> "InputSource[T]":
"""Create an InputSource from a callable."""
from .task_impl import SimpleCallDataInputSource
return SimpleCallDataInputSource()

View File

@ -261,13 +261,42 @@ def _is_async_iterator(obj):
)
def _is_async_iterable(obj):
return hasattr(obj, "__aiter__") and callable(getattr(obj, "__aiter__", None))
def _is_iterator(obj):
return (
hasattr(obj, "__iter__")
and callable(getattr(obj, "__iter__", None))
and hasattr(obj, "__next__")
and callable(getattr(obj, "__next__", None))
)
def _is_iterable(obj):
return hasattr(obj, "__iter__") and callable(getattr(obj, "__iter__", None))
async def _to_async_iterator(obj) -> AsyncIterator:
if _is_async_iterable(obj):
async for item in obj:
yield item
elif _is_iterable(obj):
for item in obj:
yield item
else:
raise ValueError(f"Can not convert {obj} to AsyncIterator")
class BaseInputSource(InputSource, ABC):
"""The base class of InputSource."""
def __init__(self) -> None:
def __init__(self, streaming: Optional[bool] = None) -> None:
"""Create a BaseInputSource."""
super().__init__()
self._is_read = False
self._streaming_data = streaming
@abstractmethod
def _read_data(self, task_ctx: TaskContext) -> Any:
@ -286,10 +315,15 @@ class BaseInputSource(InputSource, ABC):
ValueError: If the input source is a stream and has been read.
"""
data = self._read_data(task_ctx)
if _is_async_iterator(data):
if self._streaming_data is None:
streaming_data = _is_async_iterator(data) or _is_iterator(data)
else:
streaming_data = self._streaming_data
if streaming_data:
if self._is_read:
raise ValueError(f"Input iterator {data} has been read!")
output: TaskOutput = SimpleStreamTaskOutput(data)
it_data = _to_async_iterator(data)
output: TaskOutput = SimpleStreamTaskOutput(it_data)
else:
output = SimpleTaskOutput(data)
self._is_read = True
@ -299,13 +333,13 @@ class BaseInputSource(InputSource, ABC):
class SimpleInputSource(BaseInputSource):
"""The default implementation of InputSource."""
def __init__(self, data: Any) -> None:
def __init__(self, data: Any, streaming: Optional[bool] = None) -> None:
"""Create a SimpleInputSource.
Args:
data (Any): The input data.
"""
super().__init__()
super().__init__(streaming=streaming)
self._data = data
def _read_data(self, task_ctx: TaskContext) -> Any:

View File

@ -0,0 +1,118 @@
from typing import AsyncIterator
import pytest
from dbgpt.core.awel import (
DAG,
InputSource,
MapOperator,
StreamifyAbsOperator,
TransformStreamAbsOperator,
)
from dbgpt.core.awel.trigger.iterator_trigger import IteratorTrigger
class NumberProducerOperator(StreamifyAbsOperator[int, int]):
"""Create a stream of numbers from 0 to `n-1`"""
async def streamify(self, n: int) -> AsyncIterator[int]:
for i in range(n):
yield i
class MyStreamingOperator(TransformStreamAbsOperator[int, int]):
async def transform_stream(self, data: AsyncIterator[int]) -> AsyncIterator[int]:
async for i in data:
yield i * i
async def _check_stream_results(stream_results, expected_len):
assert len(stream_results) == expected_len
for _, result in stream_results:
i = 0
async for num in result:
assert num == i * i
i += 1
@pytest.mark.asyncio
async def test_single_data():
with DAG("test_single_data"):
trigger_task = IteratorTrigger(data=2)
task = MapOperator(lambda x: x * x)
trigger_task >> task
results = await trigger_task.trigger()
assert len(results) == 1
assert results[0][1] == 4
with DAG("test_single_data_stream"):
trigger_task = IteratorTrigger(data=2, streaming_call=True)
number_task = NumberProducerOperator()
task = MyStreamingOperator()
trigger_task >> number_task >> task
stream_results = await trigger_task.trigger()
await _check_stream_results(stream_results, 1)
@pytest.mark.asyncio
async def test_list_data():
with DAG("test_list_data"):
trigger_task = IteratorTrigger(data=[0, 1, 2, 3])
task = MapOperator(lambda x: x * x)
trigger_task >> task
results = await trigger_task.trigger()
assert len(results) == 4
assert results == [(0, 0), (1, 1), (2, 4), (3, 9)]
with DAG("test_list_data_stream"):
trigger_task = IteratorTrigger(data=[0, 1, 2, 3], streaming_call=True)
number_task = NumberProducerOperator()
task = MyStreamingOperator()
trigger_task >> number_task >> task
stream_results = await trigger_task.trigger()
await _check_stream_results(stream_results, 4)
@pytest.mark.asyncio
async def test_async_iterator_data():
async def async_iter():
for i in range(4):
yield i
with DAG("test_async_iterator_data"):
trigger_task = IteratorTrigger(data=async_iter())
task = MapOperator(lambda x: x * x)
trigger_task >> task
results = await trigger_task.trigger()
assert len(results) == 4
assert results == [(0, 0), (1, 1), (2, 4), (3, 9)]
with DAG("test_async_iterator_data_stream"):
trigger_task = IteratorTrigger(data=async_iter(), streaming_call=True)
number_task = NumberProducerOperator()
task = MyStreamingOperator()
trigger_task >> number_task >> task
stream_results = await trigger_task.trigger()
await _check_stream_results(stream_results, 4)
@pytest.mark.asyncio
async def test_input_source_data():
with DAG("test_input_source_data"):
trigger_task = IteratorTrigger(data=InputSource.from_iterable([0, 1, 2, 3]))
task = MapOperator(lambda x: x * x)
trigger_task >> task
results = await trigger_task.trigger()
assert len(results) == 4
assert results == [(0, 0), (1, 1), (2, 4), (3, 9)]
with DAG("test_input_source_data_stream"):
trigger_task = IteratorTrigger(
data=InputSource.from_iterable([0, 1, 2, 3]),
streaming_call=True,
)
number_task = NumberProducerOperator()
task = MyStreamingOperator()
trigger_task >> number_task >> task
stream_results = await trigger_task.trigger()
await _check_stream_results(stream_results, 4)

View File

@ -2,16 +2,18 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Generic
from ..operators.common_operator import TriggerOperator
from ..task.base import OUT
class Trigger(TriggerOperator, ABC):
class Trigger(TriggerOperator[OUT], ABC, Generic[OUT]):
"""Base class for all trigger classes.
Now only support http trigger.
"""
@abstractmethod
async def trigger(self) -> None:
async def trigger(self, **kwargs) -> Any:
"""Trigger the workflow or a specific operation in the workflow."""

View File

@ -397,9 +397,9 @@ class HttpTrigger(Trigger):
self._end_node: Optional[BaseOperator] = None
self._register_to_app = register_to_app
async def trigger(self) -> None:
async def trigger(self, **kwargs) -> Any:
"""Trigger the DAG. Not used in HttpTrigger."""
pass
raise NotImplementedError("HttpTrigger does not support trigger directly")
def register_to_app(self) -> bool:
"""Register the trigger to a FastAPI app.

View File

@ -0,0 +1,148 @@
"""Trigger for iterator data."""
import asyncio
from typing import Any, AsyncIterator, Iterator, List, Optional, Tuple, Union, cast
from ..operators.base import BaseOperator
from ..task.base import InputSource, TaskState
from ..task.task_impl import DefaultTaskContext, _is_async_iterator, _is_iterable
from .base import Trigger
IterDataType = Union[InputSource, Iterator, AsyncIterator, Any]
async def _to_async_iterator(iter_data: IterDataType, task_id: str) -> AsyncIterator:
"""Convert iter_data to an async iterator."""
if _is_async_iterator(iter_data):
async for item in iter_data: # type: ignore
yield item
elif _is_iterable(iter_data):
for item in iter_data: # type: ignore
yield item
elif isinstance(iter_data, InputSource):
task_ctx: DefaultTaskContext[Any] = DefaultTaskContext(
task_id, TaskState.RUNNING, None
)
data = await iter_data.read(task_ctx)
if data.is_stream:
async for item in data.output_stream:
yield item
else:
yield data.output
else:
yield iter_data
class IteratorTrigger(Trigger):
"""Trigger for iterator data.
Trigger the dag with iterator data.
Return the list of results of the leaf nodes in the dag.
The number of DAG runs equals the length of the iterator data.
"""
def __init__(
self,
data: IterDataType,
parallel_num: int = 1,
streaming_call: bool = False,
**kwargs
):
"""Create a IteratorTrigger.
Args:
data (IterDataType): The iterator data.
parallel_num (int, optional): The parallel number of the dag running.
Defaults to 1.
streaming_call (bool, optional): Whether the dag is a streaming call.
Defaults to False.
"""
self._iter_data = data
self._parallel_num = parallel_num
self._streaming_call = streaming_call
super().__init__(**kwargs)
async def trigger(
self, parallel_num: Optional[int] = None, **kwargs
) -> List[Tuple[Any, Any]]:
"""Trigger the dag with iterator data.
If the dag is a streaming call, return the list of async iterator.
Examples:
.. code-block:: python
import asyncio
from dbgpt.core.awel import DAG, IteratorTrigger, MapOperator
with DAG("test_dag") as dag:
trigger_task = IteratorTrigger([0, 1, 2, 3])
task = MapOperator(lambda x: x * x)
trigger_task >> task
results = asyncio.run(trigger_task.trigger())
# First element of the tuple is the input data, the second element is the
# output data of the leaf node.
assert results == [(0, 0), (1, 1), (2, 4), (3, 9)]
.. code-block:: python
import asyncio
from datasets import Dataset
from dbgpt.core.awel import (
DAG,
IteratorTrigger,
MapOperator,
InputSource,
)
data_samples = {
"question": ["What is 1+1?", "What is 7*7?"],
"answer": [2, 49],
}
dataset = Dataset.from_dict(data_samples)
with DAG("test_dag_stream") as dag:
trigger_task = IteratorTrigger(InputSource.from_iterable(dataset))
task = MapOperator(lambda x: x["answer"])
trigger_task >> task
results = asyncio.run(trigger_task.trigger())
assert results == [
({"question": "What is 1+1?", "answer": 2}, 2),
({"question": "What is 7*7?", "answer": 49}, 49),
]
Args:
parallel_num (Optional[int], optional): The parallel number of the dag
running. Defaults to None.
Returns:
List[Tuple[Any, Any]]: The list of results of the leaf nodes in the dag.
The first element of the tuple is the input data, the second element is
the output data of the leaf node.
"""
dag = self.dag
if not dag:
raise ValueError("DAG is not set for IteratorTrigger")
leaf_nodes = dag.leaf_nodes
if len(leaf_nodes) != 1:
raise ValueError("IteratorTrigger just support one leaf node in dag")
end_node = cast(BaseOperator, leaf_nodes[0])
streaming_call = self._streaming_call
semaphore = asyncio.Semaphore(parallel_num or self._parallel_num)
task_id = self.node_id
async def call_stream(call_data: Any):
async for out in await end_node.call_stream(call_data):
yield out
async def run_node(call_data: Any):
async with semaphore:
if streaming_call:
task_output = call_stream(call_data)
else:
task_output = await end_node.call(call_data)
return call_data, task_output
tasks = []
async for data in _to_async_iterator(self._iter_data, task_id):
tasks.append(run_node(data))
results = await asyncio.gather(*tasks)
return results

View File

@ -0,0 +1,32 @@
"""Interface for embedding models."""
import asyncio
from abc import ABC, abstractmethod
from typing import List
class Embeddings(ABC):
"""Interface for embedding models.
Refer to `Langchain Embeddings <https://github.com/langchain-ai/langchain/tree/
master/libs/langchain/langchain/embeddings>`_.
"""
@abstractmethod
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Embed search docs."""
@abstractmethod
def embed_query(self, text: str) -> List[float]:
"""Embed query text."""
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
"""Asynchronous Embed search docs."""
return await asyncio.get_running_loop().run_in_executor(
None, self.embed_documents, texts
)
async def aembed_query(self, text: str) -> List[float]:
"""Asynchronous Embed query text."""
return await asyncio.get_running_loop().run_in_executor(
None, self.embed_query, text
)
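For orientation, a minimal concrete subclass of this interface might look like the sketch below. The ToyEmbeddings class and its pseudo-vectors are purely illustrative assumptions, not part of this commit; real implementations (such as the HuggingFace-based classes in dbgpt.rag.embedding) delegate to an actual model.

# Illustrative sketch only: a toy Embeddings subclass. It implements the two
# abstract methods above and inherits the async wrappers unchanged.
from typing import List

from dbgpt.core.interface.embeddings import Embeddings


class ToyEmbeddings(Embeddings):
    """Hypothetical embedding model used only for demonstration."""

    def embed_query(self, text: str) -> List[float]:
        # Map a text to a deterministic 8-dimensional pseudo-vector.
        return [float(ord(c) % 10) for c in text[:8].ljust(8)]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        # One pseudo-vector per document.
        return [self.embed_query(text) for text in texts]

Any such implementation can then be passed wherever an Embeddings is expected, for example to the SimilarityMetric and RetrieverEvaluator introduced later in this commit.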

View File

@ -0,0 +1,253 @@
"""Evaluation module."""
import asyncio
import string
from abc import ABC, abstractmethod
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Callable,
Generic,
Iterator,
List,
Optional,
Sequence,
TypeVar,
Union,
)
from dbgpt._private.pydantic import BaseModel, Field
from dbgpt.util.similarity_util import calculate_cosine_similarity
from .embeddings import Embeddings
from .llm import LLMClient
if TYPE_CHECKING:
from dbgpt.core.awel.task.base import InputSource
QueryType = Union[str, Any]
PredictionType = Union[str, Any]
ContextType = Union[str, Sequence[str], Any]
DatasetType = Union["InputSource", Iterator, AsyncIterator]
class BaseEvaluationResult(BaseModel):
"""Base evaluation result."""
prediction: Optional[PredictionType] = Field(
None,
description="Prediction data(including the output of LLM, the data from "
"retrieval, etc.)",
)
contexts: Optional[ContextType] = Field(None, description="Context data")
score: Optional[float] = Field(None, description="Score for the prediction")
passing: Optional[bool] = Field(
None, description="Binary evaluation result (passing or not)"
)
metric_name: Optional[str] = Field(None, description="Name of the metric")
class EvaluationResult(BaseEvaluationResult):
"""Evaluation result.
Output of a BaseEvaluator.
"""
query: Optional[QueryType] = Field(None, description="Query data")
raw_dataset: Optional[Any] = Field(None, description="Raw dataset")
Q = TypeVar("Q")
P = TypeVar("P")
C = TypeVar("C")
class EvaluationMetric(ABC, Generic[P, C]):
"""Base class for evaluation metric."""
@property
def name(self) -> str:
"""Name of the metric."""
return self.__class__.__name__
async def compute(
self,
prediction: P,
contexts: Optional[Sequence[C]] = None,
) -> BaseEvaluationResult:
"""Compute the evaluation metric.
Args:
prediction(P): The prediction data.
contexts(Optional[Sequence[C]]): The context data.
Returns:
BaseEvaluationResult: The evaluation result.
"""
return await asyncio.get_running_loop().run_in_executor(
None, self.sync_compute, prediction, contexts
)
def sync_compute(
self,
prediction: P,
contexts: Optional[Sequence[C]] = None,
) -> BaseEvaluationResult:
"""Compute the evaluation metric.
Args:
prediction(P): The prediction data.
contexts(Optional[Sequence[C]]): The context data.
Returns:
BaseEvaluationResult: The evaluation result.
"""
raise NotImplementedError("sync_compute is not implemented")
class FunctionMetric(EvaluationMetric[P, C], Generic[P, C]):
"""Evaluation metric based on a function."""
def __init__(
self,
name: str,
func: Callable[
[P, Optional[Sequence[C]]],
BaseEvaluationResult,
],
):
"""Create a FunctionMetric.
Args:
name(str): The name of the metric.
func(Callable[[P, Optional[Sequence[C]]], BaseEvaluationResult]):
The function to use for evaluation.
"""
self._name = name
self.func = func
@property
def name(self) -> str:
"""Name of the metric."""
return self._name
async def compute(
self,
prediction: P,
contexts: Optional[Sequence[C]] = None,
) -> BaseEvaluationResult:
"""Compute the evaluation metric."""
return self.func(prediction, contexts)
class ExactMatchMetric(EvaluationMetric[str, str]):
"""Exact match metric.
Only supports string predictions and contexts.
"""
def __init__(self, ignore_case: bool = False, ignore_punctuation: bool = False):
"""Create an ExactMatchMetric."""
self._ignore_case = ignore_case
self._ignore_punctuation = ignore_punctuation
async def compute(
self,
prediction: str,
contexts: Optional[Sequence[str]] = None,
) -> BaseEvaluationResult:
"""Compute the evaluation metric."""
if self._ignore_case:
prediction = prediction.lower()
if contexts:
contexts = [c.lower() for c in contexts]
if self._ignore_punctuation:
prediction = prediction.translate(str.maketrans("", "", string.punctuation))
if contexts:
contexts = [
c.translate(str.maketrans("", "", string.punctuation))
for c in contexts
]
score = 0 if not contexts else float(prediction in contexts)
return BaseEvaluationResult(
prediction=prediction,
contexts=contexts,
score=score,
)
class SimilarityMetric(EvaluationMetric[str, str]):
"""Similarity metric.
Calculate the cosine similarity between a prediction and a list of contexts.
"""
def __init__(self, embeddings: Embeddings):
"""Create a SimilarityMetric with embeddings."""
self._embeddings = embeddings
def sync_compute(
self,
prediction: str,
contexts: Optional[Sequence[str]] = None,
) -> BaseEvaluationResult:
"""Compute the evaluation metric."""
if not contexts:
return BaseEvaluationResult(
prediction=prediction,
contexts=contexts,
score=0.0,
)
try:
import numpy as np
except ImportError:
raise ImportError("numpy is required for SimilarityMetric")
similarity: np.ndarray = calculate_cosine_similarity(
self._embeddings, prediction, contexts
)
return BaseEvaluationResult(
prediction=prediction,
contexts=contexts,
score=float(similarity.mean()),
)
class Evaluator(ABC):
"""Base Evaluator class."""
def __init__(
self,
llm_client: Optional[LLMClient] = None,
):
"""Create an Evaluator."""
self.llm_client = llm_client
@abstractmethod
async def evaluate(
self,
dataset: DatasetType,
metrics: Optional[List[EvaluationMetric]] = None,
query_key: str = "query",
contexts_key: str = "contexts",
prediction_key: str = "prediction",
parallel_num: int = 1,
**kwargs
) -> List[List[EvaluationResult]]:
"""Run evaluation with a dataset and metrics.
Args:
dataset(DatasetType): The dataset to evaluate.
metrics(Optional[List[EvaluationMetric]]): The metrics to use for
evaluation.
query_key(str): The key for query in the dataset.
contexts_key(str): The key for contexts in the dataset.
prediction_key(str): The key for prediction in the dataset.
parallel_num(int): The number of parallel tasks.
kwargs: Additional arguments.
Returns:
List[List[EvaluationResult]]: The evaluation results, the length of the
result equals to the length of the dataset. The first element in the
list is the list of evaluation results for metrics.
"""

View File

@ -5,9 +5,10 @@ from typing import Any, List, Optional
from pydantic import BaseModel, Field
from dbgpt.rag.chunk import Chunk
from dbgpt.rag.chunk import Chunk, Document
from dbgpt.rag.extractor.base import Extractor
from dbgpt.rag.knowledge.base import ChunkStrategy, Knowledge
from dbgpt.rag.text_splitter import TextSplitter
class SplitterType(Enum):
@ -81,14 +82,14 @@ class ChunkManager:
self._text_splitter = self._chunk_parameters.text_splitter
self._splitter_type = self._chunk_parameters.splitter_type
def split(self, documents) -> List[Chunk]:
def split(self, documents: List[Document]) -> List[Chunk]:
"""Split a document into chunks."""
text_splitter = self._select_text_splitter()
if SplitterType.LANGCHAIN == self._splitter_type:
documents = text_splitter.split_documents(documents)
return [Chunk.langchain2chunk(document) for document in documents]
elif SplitterType.LLAMA_INDEX == self._splitter_type:
nodes = text_splitter.split_text(documents)
nodes = text_splitter.split_documents(documents)
return [Chunk.llamaindex2chunk(node) for node in nodes]
else:
return text_splitter.split_documents(documents)
@ -106,7 +107,7 @@ class ChunkManager:
def set_text_splitter(
self,
text_splitter,
text_splitter: TextSplitter,
splitter_type: SplitterType = SplitterType.LANGCHAIN,
) -> None:
"""Add text splitter."""
@ -115,13 +116,13 @@ class ChunkManager:
def get_text_splitter(
self,
) -> Any:
) -> TextSplitter:
"""Return text splitter."""
return self._select_text_splitter()
def _select_text_splitter(
self,
):
) -> TextSplitter:
"""Select text splitter by chunk strategy."""
if self._text_splitter:
return self._text_splitter

View File

@ -1,13 +1,13 @@
"""Embedding implementations."""
import asyncio
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
import aiohttp
import requests
from dbgpt._private.pydantic import BaseModel, Extra, Field
from dbgpt.core import Embeddings
DEFAULT_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
DEFAULT_INSTRUCT_MODEL = "hkunlp/instructor-large"
@ -22,34 +22,6 @@ DEFAULT_QUERY_BGE_INSTRUCTION_EN = (
DEFAULT_QUERY_BGE_INSTRUCTION_ZH = "为这个句子生成表示以用于检索相关文章:"
class Embeddings(ABC):
"""Interface for embedding models.
Refer to `Langchain Embeddings <https://github.com/langchain-ai/langchain/tree/
master/libs/langchain/langchain/embeddings>`_.
"""
@abstractmethod
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Embed search docs."""
@abstractmethod
def embed_query(self, text: str) -> List[float]:
"""Embed query text."""
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
"""Asynchronous Embed search docs."""
return await asyncio.get_running_loop().run_in_executor(
None, self.embed_documents, texts
)
async def aembed_query(self, text: str) -> List[float]:
"""Asynchronous Embed query text."""
return await asyncio.get_running_loop().run_in_executor(
None, self.embed_query, text
)
class HuggingFaceEmbeddings(BaseModel, Embeddings):
"""HuggingFace sentence_transformers embedding models.

View File

@ -0,0 +1,13 @@
"""Module for evaluation of RAG."""
from .retriever import ( # noqa: F401
RetrieverEvaluationMetric,
RetrieverEvaluator,
RetrieverSimilarityMetric,
)
__ALL__ = [
"RetrieverEvaluator",
"RetrieverSimilarityMetric",
"RetrieverEvaluationMetric",
]

View File

@ -0,0 +1,171 @@
"""Evaluation for retriever."""
from abc import ABC
from typing import Any, Dict, List, Optional, Sequence, Type
from dbgpt.core import Embeddings, LLMClient
from dbgpt.core.interface.evaluation import (
BaseEvaluationResult,
DatasetType,
EvaluationMetric,
EvaluationResult,
Evaluator,
)
from dbgpt.core.interface.operators.retriever import RetrieverOperator
from dbgpt.util.similarity_util import calculate_cosine_similarity
from ..operators.evaluation import RetrieverEvaluatorOperator
class RetrieverEvaluationMetric(EvaluationMetric[List[str], str], ABC):
"""Evaluation metric for retriever.
The prediction is a list of str (content from chunks) and the context is a string.
"""
class RetrieverSimilarityMetric(RetrieverEvaluationMetric):
"""Similarity metric for retriever."""
def __init__(self, embeddings: Embeddings):
"""Create a SimilarityMetric with embeddings."""
self._embeddings = embeddings
def sync_compute(
self,
prediction: List[str],
contexts: Optional[Sequence[str]] = None,
) -> BaseEvaluationResult:
"""Compute the evaluation metric.
Args:
prediction(List[str]): The retrieved chunks from the retriever.
contexts(Sequence[str]): The contexts from dataset.
Returns:
BaseEvaluationResult: The evaluation result.
The score is the mean of the cosine similarity between the prediction
and the contexts.
"""
if not prediction or not contexts:
return BaseEvaluationResult(
prediction=prediction,
contexts=contexts,
score=0.0,
)
try:
import numpy as np
except ImportError:
raise ImportError("numpy is required for RelevancySimilarityMetric")
similarity: np.ndarray = calculate_cosine_similarity(
self._embeddings, contexts[0], prediction
)
return BaseEvaluationResult(
prediction=prediction,
contexts=contexts,
score=float(similarity.mean()),
)
class RetrieverEvaluator(Evaluator):
"""Evaluator for relevancy.
Examples:
.. code-block:: python
import os
import asyncio
from dbgpt.rag.operators import (
EmbeddingRetrieverOperator,
RetrieverEvaluatorOperator,
)
from dbgpt.rag.evaluation import (
RetrieverEvaluator,
RetrieverSimilarityMetric,
)
from dbgpt.rag.embedding import DefaultEmbeddingFactory
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
from dbgpt.storage.vector_store.connector import VectorStoreConnector
from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH
embeddings = DefaultEmbeddingFactory(
default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"),
).create()
vector_connector = VectorStoreConnector.from_default(
"Chroma",
vector_store_config=ChromaVectorConfig(
name="my_test_schema",
persist_path=os.path.join(PILOT_PATH, "data"),
),
embedding_fn=embeddings,
)
dataset = [
{
"query": "what is awel talk about",
"contexts": [
"Through the AWEL API, you can focus on the development"
" of business logic for LLMs applications without paying "
"attention to cumbersome model and environment details."
],
},
]
evaluator = RetrieverEvaluator(
operator_cls=EmbeddingRetrieverOperator,
embeddings=embeddings,
operator_kwargs={
"top_k": 5,
"vector_store_connector": vector_connector,
},
)
results = asyncio.run(evaluator.evaluate(dataset))
"""
def __init__(
self,
operator_cls: Type[RetrieverOperator],
llm_client: Optional[LLMClient] = None,
embeddings: Optional[Embeddings] = None,
operator_kwargs: Optional[Dict] = None,
):
"""Create a new RetrieverEvaluator."""
if not operator_kwargs:
operator_kwargs = {}
self._operator_cls = operator_cls
self._operator_kwargs: Dict[str, Any] = operator_kwargs
self.embeddings = embeddings
super().__init__(llm_client=llm_client)
async def evaluate(
self,
dataset: DatasetType,
metrics: Optional[List[EvaluationMetric]] = None,
query_key: str = "query",
contexts_key: str = "contexts",
prediction_key: str = "prediction",
parallel_num: int = 1,
**kwargs
) -> List[List[EvaluationResult]]:
"""Evaluate the dataset."""
from dbgpt.core.awel import DAG, IteratorTrigger, MapOperator
if not metrics:
if not self.embeddings:
raise ValueError("embeddings are required for SimilarityMetric")
metrics = [RetrieverSimilarityMetric(self.embeddings)]
with DAG("relevancy_evaluation_dag"):
input_task = IteratorTrigger(dataset)
query_task: MapOperator = MapOperator(lambda x: x[query_key])
retriever_task = self._operator_cls(**self._operator_kwargs)
retriever_eva_task = RetrieverEvaluatorOperator(
evaluation_metrics=metrics, llm_client=self.llm_client
)
input_task >> query_task
query_task >> retriever_eva_task
query_task >> retriever_task >> retriever_eva_task
input_task >> MapOperator(lambda x: x[contexts_key]) >> retriever_eva_task
input_task >> retriever_eva_task
results = await input_task.trigger(parallel_num=parallel_num)
return [item for _, item in results]

View File

@ -154,7 +154,7 @@ class Knowledge(ABC):
self._type = knowledge_type
self._data_loader = data_loader
def load(self):
def load(self) -> List[Document]:
"""Load knowledge from data_loader."""
documents = self._load()
return self._postprocess(documents)
@ -174,7 +174,7 @@ class Knowledge(ABC):
return docs
@abstractmethod
def _load(self):
def _load(self) -> List[Document]:
"""Preprocess knowledge from data_loader."""
@classmethod

View File

@ -3,6 +3,7 @@
from .datasource import DatasourceRetrieverOperator # noqa: F401
from .db_schema import DBSchemaRetrieverOperator # noqa: F401
from .embedding import EmbeddingRetrieverOperator # noqa: F401
from .evaluation import RetrieverEvaluatorOperator # noqa: F401
from .knowledge import KnowledgeOperator # noqa: F401
from .rerank import RerankOperator # noqa: F401
from .rewrite import QueryRewriteOperator # noqa: F401
@ -16,4 +17,5 @@ __all__ = [
"RerankOperator",
"QueryRewriteOperator",
"SummaryAssemblerOperator",
"RetrieverEvaluatorOperator",
]

View File

@ -1,16 +1,17 @@
"""Embedding retriever operator."""
from functools import reduce
from typing import Any, Optional
from typing import List, Optional, Union
from dbgpt.core.interface.operators.retriever import RetrieverOperator
from dbgpt.rag.chunk import Chunk
from dbgpt.rag.retriever.embedding import EmbeddingRetriever
from dbgpt.rag.retriever.rerank import Ranker
from dbgpt.rag.retriever.rewrite import QueryRewrite
from dbgpt.storage.vector_store.connector import VectorStoreConnector
class EmbeddingRetrieverOperator(RetrieverOperator[Any, Any]):
class EmbeddingRetrieverOperator(RetrieverOperator[Union[str, List[str]], List[Chunk]]):
"""The Embedding Retriever Operator."""
def __init__(
@ -32,7 +33,7 @@ class EmbeddingRetrieverOperator(RetrieverOperator[Any, Any]):
rerank=rerank,
)
def retrieve(self, query: Any) -> Any:
def retrieve(self, query: Union[str, List[str]]) -> List[Chunk]:
"""Retrieve the candidates."""
if isinstance(query, str):
return self._retriever.retrieve_with_scores(query, self._score_threshold)

View File

@ -0,0 +1,61 @@
"""Evaluation operators."""
import asyncio
from typing import Any, List, Optional
from dbgpt.core.awel import JoinOperator
from dbgpt.core.interface.evaluation import EvaluationMetric, EvaluationResult
from dbgpt.core.interface.llm import LLMClient
from ..chunk import Chunk
class RetrieverEvaluatorOperator(JoinOperator[List[EvaluationResult]]):
"""Evaluator for retriever."""
def __init__(
self,
evaluation_metrics: List[EvaluationMetric],
llm_client: Optional[LLMClient] = None,
**kwargs,
):
"""Create a new RetrieverEvaluatorOperator."""
self.llm_client = llm_client
self.evaluation_metrics = evaluation_metrics
super().__init__(combine_function=self._do_evaluation, **kwargs)
async def _do_evaluation(
self,
query: str,
prediction: List[Chunk],
contexts: List[str],
raw_dataset: Any = None,
) -> List[EvaluationResult]:
"""Run evaluation.
Args:
query(str): The query string.
prediction(List[Chunk]): The retrieved chunks from the retriever.
contexts(List[str]): The contexts from dataset.
raw_dataset(Any): The raw data(single row) from dataset.
"""
if isinstance(contexts, str):
contexts = [contexts]
prediction_strs = [chunk.content for chunk in prediction]
tasks = []
for metric in self.evaluation_metrics:
tasks.append(metric.compute(prediction_strs, contexts))
task_results = await asyncio.gather(*tasks)
results = []
for result, metric in zip(task_results, self.evaluation_metrics):
results.append(
EvaluationResult(
query=query,
prediction=prediction,
score=result.score,
contexts=contexts,
passing=result.passing,
raw_dataset=raw_dataset,
metric_name=metric.name,
)
)
return results

View File

@ -256,6 +256,6 @@ class AgentDummyTrigger(Trigger):
"""Initialize a HttpTrigger."""
super().__init__(**kwargs)
async def trigger(self) -> None:
async def trigger(self, **kwargs) -> None:
"""Trigger the DAG. Not used in HttpTrigger."""
pass
raise NotImplementedError("Dummy trigger does not support trigger.")

View File

@ -44,8 +44,10 @@ class BaseAssembler(ABC):
with root_tracer.start_span("BaseAssembler.load_knowledge", metadata=metadata):
self.load_knowledge(self._knowledge)
def load_knowledge(self, knowledge) -> None:
def load_knowledge(self, knowledge: Optional[Knowledge] = None) -> None:
"""Load knowledge Pipeline."""
if not knowledge:
raise ValueError("knowledge must be provided.")
with root_tracer.start_span("BaseAssembler.knowledge.load"):
documents = knowledge.load()
with root_tracer.start_span("BaseAssembler.chunk_manager.split"):
@ -56,8 +58,12 @@ class BaseAssembler(ABC):
"""Return a retriever."""
@abstractmethod
def persist(self, chunks: List[Chunk]) -> None:
"""Persist chunks."""
def persist(self) -> List[str]:
"""Persist chunks.
Returns:
List[str]: List of persisted chunk ids.
"""
def get_chunks(self) -> List[Chunk]:
"""Return chunks."""

View File

@ -129,7 +129,11 @@ class DBSchemaAssembler(BaseAssembler):
return self._chunks
def persist(self) -> List[str]:
"""Persist chunks into vector store."""
"""Persist chunks into vector store.
Returns:
List[str]: List of chunk ids.
"""
return self._vector_store_connector.load_document(self._chunks)
def _extract_info(self, chunks) -> List[Chunk]:

View File

@ -29,7 +29,7 @@ class EmbeddingAssembler(BaseAssembler):
def __init__(
self,
knowledge: Knowledge = None,
knowledge: Knowledge,
chunk_parameters: Optional[ChunkParameters] = None,
embedding_model: Optional[str] = None,
embedding_factory: Optional[EmbeddingFactory] = None,
@ -69,7 +69,7 @@ class EmbeddingAssembler(BaseAssembler):
@classmethod
def load_from_knowledge(
cls,
knowledge: Knowledge = None,
knowledge: Knowledge,
chunk_parameters: Optional[ChunkParameters] = None,
embedding_model: Optional[str] = None,
embedding_factory: Optional[EmbeddingFactory] = None,
@ -99,7 +99,11 @@ class EmbeddingAssembler(BaseAssembler):
)
def persist(self) -> List[str]:
"""Persist chunks into vector store."""
"""Persist chunks into vector store.
Returns:
List[str]: List of chunk ids.
"""
return self._vector_store_connector.load_document(self._chunks)
def _extract_info(self, chunks) -> List[Chunk]:

View File

@ -32,7 +32,7 @@ class SummaryAssembler(BaseAssembler):
def __init__(
self,
knowledge: Knowledge = None,
knowledge: Knowledge,
chunk_parameters: Optional[ChunkParameters] = None,
model_name: Optional[str] = None,
llm_client: Optional[LLMClient] = None,
@ -69,7 +69,7 @@ class SummaryAssembler(BaseAssembler):
@classmethod
def load_from_knowledge(
cls,
knowledge: Knowledge = None,
knowledge: Knowledge,
chunk_parameters: Optional[ChunkParameters] = None,
model_name: Optional[str] = None,
llm_client: Optional[LLMClient] = None,
@ -104,6 +104,7 @@ class SummaryAssembler(BaseAssembler):
def persist(self) -> List[str]:
"""Persist chunks into store."""
raise NotImplementedError
def _extract_info(self, chunks) -> List[Chunk]:
"""Extract info from chunks."""

View File

@ -0,0 +1,37 @@
"""Utility functions for calculating similarity."""
from typing import TYPE_CHECKING, Any, Sequence
if TYPE_CHECKING:
from dbgpt.core.interface.embeddings import Embeddings
def calculate_cosine_similarity(
embeddings: "Embeddings", prediction: str, contexts: Sequence[str]
) -> Any:
"""Calculate the cosine similarity between a prediction and a list of contexts.
Args:
embeddings(Embeddings): The embeddings to use.
prediction(str): The prediction.
contexts(Sequence[str]): The contexts.
Returns:
numpy.ndarray: The cosine similarity.
"""
try:
import numpy as np
except ImportError:
raise ImportError("numpy is required for SimilarityMetric")
prediction_vec = np.asarray(embeddings.embed_query(prediction)).reshape(1, -1)
context_list = list(contexts)
context_list_vec = np.asarray(embeddings.embed_documents(context_list)).reshape(
len(contexts), -1
)
# cos(a,b) = dot(a,b) / (norm(a) * norm(b))
dot = np.dot(context_list_vec, prediction_vec.T).reshape(
-1,
)
norm = np.linalg.norm(context_list_vec, axis=1) * np.linalg.norm(
prediction_vec, axis=1
)
return dot / norm
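A brief usage sketch of this helper follows. The two-dimensional stub embedder is a hypothetical stand-in that only keeps the example self-contained; in practice any real Embeddings implementation is passed in.

# Illustrative sketch only: calling calculate_cosine_similarity with a stub.
from typing import List

from dbgpt.core.interface.embeddings import Embeddings
from dbgpt.util.similarity_util import calculate_cosine_similarity


class _StubEmbeddings(Embeddings):
    """Hypothetical embedder: text length and vowel count as a crude feature vector."""

    def embed_query(self, text: str) -> List[float]:
        return [float(len(text)), float(sum(c in "aeiou" for c in text))]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [self.embed_query(t) for t in texts]


similarity = calculate_cosine_similarity(
    _StubEmbeddings(), "awel workflow", ["awel workflow", "unrelated text"]
)
print(similarity)         # one cosine score per context
print(similarity.mean())  # the value SimilarityMetric reports as its score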

View File

@ -0,0 +1,82 @@
import asyncio
import os
from typing import Optional
from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH, ROOT_PATH
from dbgpt.core import Embeddings
from dbgpt.rag.chunk_manager import ChunkParameters
from dbgpt.rag.embedding import DefaultEmbeddingFactory
from dbgpt.rag.evaluation import RetrieverEvaluator
from dbgpt.rag.knowledge import KnowledgeFactory
from dbgpt.rag.operators import EmbeddingRetrieverOperator
from dbgpt.serve.rag.assembler.embedding import EmbeddingAssembler
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
from dbgpt.storage.vector_store.connector import VectorStoreConnector
def _create_embeddings(
model_name: Optional[str] = "text2vec-large-chinese",
) -> Embeddings:
"""Create embeddings."""
return DefaultEmbeddingFactory(
default_model_name=os.path.join(MODEL_PATH, model_name),
).create()
def _create_vector_connector(
embeddings: Embeddings, space_name: str = "retriever_evaluation_example"
) -> VectorStoreConnector:
"""Create vector connector."""
return VectorStoreConnector.from_default(
"Chroma",
vector_store_config=ChromaVectorConfig(
name=space_name,
persist_path=os.path.join(PILOT_PATH, "data"),
),
embedding_fn=embeddings,
)
async def main():
file_path = os.path.join(ROOT_PATH, "docs/docs/awel/awel.md")
knowledge = KnowledgeFactory.from_file_path(file_path)
embeddings = _create_embeddings()
vector_connector = _create_vector_connector(embeddings)
chunk_parameters = ChunkParameters(chunk_strategy="CHUNK_BY_SIZE")
# get embedding assembler
assembler = EmbeddingAssembler.load_from_knowledge(
knowledge=knowledge,
chunk_parameters=chunk_parameters,
vector_store_connector=vector_connector,
)
assembler.persist()
dataset = [
{
"query": "what is awel talk about",
"contexts": [
"Through the AWEL API, you can focus on the development"
" of business logic for LLMs applications without paying "
"attention to cumbersome model and environment details."
],
},
]
evaluator = RetrieverEvaluator(
operator_cls=EmbeddingRetrieverOperator,
embeddings=embeddings,
operator_kwargs={
"top_k": 5,
"vector_store_connector": vector_connector,
},
)
results = await evaluator.evaluate(dataset)
for result in results:
for metric in result:
print("Metric:", metric.metric_name)
print("Question:", metric.query)
print("Score:", metric.score)
print(f"Results:\n{results}")
if __name__ == "__main__":
asyncio.run(main())

View File

@ -7,7 +7,7 @@
curl --location --request POST 'http://127.0.0.1:5555/api/v1/awel/trigger/examples/rag/embedding' \
--header 'Content-Type: application/json' \
--data-raw '{
"url": "https://docs.dbgpt.site/docs/awel"
"url": "https://docs.dbgpt.site/docs/latest/awel/"
}'
"""

View File

@ -57,6 +57,8 @@ clean_local_data() {
rm -rf /root/DB-GPT/pilot/message
rm -f /root/DB-GPT/logs/*
rm -f /root/DB-GPT/logsDbChatOutputParser.log
rm -rf /root/DB-GPT/pilot/meta_data/alembic/versions/*
rm -rf /root/DB-GPT/pilot/meta_data/*.db
}
usage() {