diff --git a/.env.template b/.env.template index f4213accb..2cf4510c9 100644 --- a/.env.template +++ b/.env.template @@ -272,6 +272,17 @@ DBGPT_LOG_LEVEL=INFO # API_KEYS - The list of API keys that are allowed to access the API. Each of the below are an option, separated by commas. # API_KEYS=dbgpt +#*******************************************************************# +#** ENCRYPT **# +#*******************************************************************# +# ENCRYPT KEY - The key used to encrypt and decrypt the data +# ENCRYPT_KEY=your_secret_key + +#*******************************************************************# +#** File Server **# +#*******************************************************************# +## The local storage path of the file server, the default is pilot/data/file_server +# FILE_SERVER_LOCAL_STORAGE_PATH = #*******************************************************************# #** Application Config **# diff --git a/.mypy.ini b/.mypy.ini index b8e214221..d9e3a7589 100644 --- a/.mypy.ini +++ b/.mypy.ini @@ -119,3 +119,6 @@ ignore_missing_imports = True [mypy-networkx.*] ignore_missing_imports = True + +[mypy-pypdf.*] +ignore_missing_imports = True diff --git a/assets/schema/dbgpt.sql b/assets/schema/dbgpt.sql index 0cdd7d17e..0d6e1c91b 100644 --- a/assets/schema/dbgpt.sql +++ b/assets/schema/dbgpt.sql @@ -32,12 +32,14 @@ CREATE TABLE IF NOT EXISTS `knowledge_document` `id` int NOT NULL AUTO_INCREMENT COMMENT 'auto increment id', `doc_name` varchar(100) NOT NULL COMMENT 'document path name', `doc_type` varchar(50) NOT NULL COMMENT 'doc type', + `doc_token` varchar(100) NOT NULL COMMENT 'doc token', `space` varchar(50) NOT NULL COMMENT 'knowledge space', `chunk_size` int NOT NULL COMMENT 'chunk size', `last_sync` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'last sync time', `status` varchar(50) NOT NULL COMMENT 'status TODO,RUNNING,FAILED,FINISHED', `content` LONGTEXT NOT NULL COMMENT 'knowledge embedding sync result', `result` TEXT NULL COMMENT 'knowledge content', + `questions` TEXT NULL COMMENT 'document related questions', `vector_ids` LONGTEXT NULL COMMENT 'vector_ids', `summary` LONGTEXT NULL COMMENT 'knowledge summary', `gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'created time', @@ -53,6 +55,7 @@ CREATE TABLE IF NOT EXISTS `document_chunk` `doc_type` varchar(50) NOT NULL COMMENT 'doc type', `document_id` int NOT NULL COMMENT 'document parent id', `content` longtext NOT NULL COMMENT 'chunk content', + `questions` text NULL COMMENT 'chunk related questions', `meta_info` varchar(200) NOT NULL COMMENT 'metadata info', `gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time', `gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time', @@ -61,7 +64,6 @@ CREATE TABLE IF NOT EXISTS `document_chunk` ) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT='knowledge document chunk detail'; - CREATE TABLE IF NOT EXISTS `connect_config` ( `id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', @@ -74,6 +76,9 @@ CREATE TABLE IF NOT EXISTS `connect_config` `db_pwd` varchar(255) DEFAULT NULL COMMENT 'db password', `comment` text COMMENT 'db comment', `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code', + `user_name` varchar(255) DEFAULT NULL COMMENT 'user name', + `user_id` varchar(255) DEFAULT NULL COMMENT 'user id', + PRIMARY KEY (`id`), UNIQUE KEY `uk_db` (`db_name`), KEY `idx_q_db_type` (`db_type`) @@ -88,11 +93,13 @@ CREATE TABLE IF NOT EXISTS `chat_history` `user_name` 
varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'interlocutor', `messages` text COLLATE utf8mb4_unicode_ci COMMENT 'Conversation details', `message_ids` text COLLATE utf8mb4_unicode_ci COMMENT 'Message id list, split by comma', + `app_code` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'App unique code', `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code', `gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time', `gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time', UNIQUE KEY `conv_uid` (`conv_uid`), - PRIMARY KEY (`id`) + PRIMARY KEY (`id`), + KEY `idx_chat_his_app_code` (`app_code`) ) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT 'Chat history'; CREATE TABLE IF NOT EXISTS `chat_history_message` @@ -108,6 +115,7 @@ CREATE TABLE IF NOT EXISTS `chat_history_message` PRIMARY KEY (`id`) ) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT 'Chat history message'; + CREATE TABLE IF NOT EXISTS `chat_feed_back` ( `id` bigint(20) NOT NULL AUTO_INCREMENT, @@ -118,6 +126,11 @@ CREATE TABLE IF NOT EXISTS `chat_feed_back` `question` longtext DEFAULT NULL COMMENT 'User question', `knowledge_space` varchar(128) DEFAULT NULL COMMENT 'Knowledge space name', `messages` longtext DEFAULT NULL COMMENT 'The details of user feedback', + `message_id` varchar(255) NULL COMMENT 'Message id', + `feedback_type` varchar(50) NULL COMMENT 'Feedback type like or unlike', + `reason_types` varchar(255) NULL COMMENT 'Feedback reason categories', + `remark` text NULL COMMENT 'Feedback remark', + `user_code` varchar(128) NULL COMMENT 'User code', `user_name` varchar(128) DEFAULT NULL COMMENT 'User name', `gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time', `gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time', @@ -167,17 +180,20 @@ CREATE TABLE IF NOT EXISTS `plugin_hub` CREATE TABLE IF NOT EXISTS `prompt_manage` ( `id` int(11) NOT NULL AUTO_INCREMENT, - `chat_scene` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Chat scene', - `sub_chat_scene` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Sub chat scene', - `prompt_type` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt type: common or private', - `prompt_name` varchar(256) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'prompt name', - `content` longtext COLLATE utf8mb4_unicode_ci COMMENT 'Prompt content', - `input_variables` varchar(1024) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt input variables(split by comma))', - `model` varchar(128) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt model name(we can use different models for different prompt)', - `prompt_language` varchar(32) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt language(eg:en, zh-cn)', - `prompt_format` varchar(32) COLLATE utf8mb4_unicode_ci DEFAULT 'f-string' COMMENT 'Prompt format(eg: f-string, jinja2)', - `prompt_desc` varchar(512) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt description', - `user_name` varchar(128) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'User name', + `chat_scene` varchar(100) DEFAULT NULL COMMENT 'Chat scene', + `sub_chat_scene` varchar(100) DEFAULT NULL COMMENT 'Sub chat scene', + `prompt_type` varchar(100) DEFAULT NULL COMMENT 'Prompt type: common or private', + `prompt_name` varchar(256) DEFAULT NULL COMMENT 'prompt 
name', + `prompt_code` varchar(256) DEFAULT NULL COMMENT 'prompt code', + `content` longtext COMMENT 'Prompt content', + `input_variables` varchar(1024) DEFAULT NULL COMMENT 'Prompt input variables(split by comma))', + `response_schema` text DEFAULT NULL COMMENT 'Prompt response schema', + `model` varchar(128) DEFAULT NULL COMMENT 'Prompt model name(we can use different models for different prompt)', + `prompt_language` varchar(32) DEFAULT NULL COMMENT 'Prompt language(eg:en, zh-cn)', + `prompt_format` varchar(32) DEFAULT 'f-string' COMMENT 'Prompt format(eg: f-string, jinja2)', + `prompt_desc` varchar(512) DEFAULT NULL COMMENT 'Prompt description', + `user_code` varchar(128) DEFAULT NULL COMMENT 'User code', + `user_name` varchar(128) DEFAULT NULL COMMENT 'User name', `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code', `gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time', `gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time', @@ -186,6 +202,8 @@ CREATE TABLE IF NOT EXISTS `prompt_manage` KEY `gmt_created_idx` (`gmt_created`) ) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='Prompt management table'; + + CREATE TABLE IF NOT EXISTS `gpts_conversations` ( `id` int(11) NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', `conv_id` varchar(255) NOT NULL COMMENT 'The unique id of the conversation record', @@ -232,11 +250,15 @@ CREATE TABLE `gpts_messages` ( `receiver` varchar(255) NOT NULL COMMENT 'Who receive message in the current conversation turn', `model_name` varchar(255) DEFAULT NULL COMMENT 'message generate model', `rounds` int(11) NOT NULL COMMENT 'dialogue turns', + `is_success` int(4) NULL DEFAULT 0 COMMENT 'agent message is success', + `app_code` varchar(255) NOT NULL COMMENT 'Current AI assistant code', + `app_name` varchar(255) NOT NULL COMMENT 'Current AI assistant name', `content` text COMMENT 'Content of the speech', `current_goal` text COMMENT 'The target corresponding to the current message', `context` text COMMENT 'Current conversation context', `review_info` text COMMENT 'Current conversation review info', `action_report` text COMMENT 'Current conversation action report', + `resource_info` text DEFAULT NULL COMMENT 'Current conversation resource info', `role` varchar(255) DEFAULT NULL COMMENT 'The role of the current message content', `created_at` datetime DEFAULT NULL COMMENT 'create time', `updated_at` datetime DEFAULT NULL COMMENT 'last update time', @@ -286,6 +308,7 @@ CREATE TABLE `dbgpt_serve_flow` ( `define_type` varchar(32) null comment 'Flow define type(json or python)', `label` varchar(128) DEFAULT NULL COMMENT 'Flow label', `editable` int DEFAULT NULL COMMENT 'Editable, 0: editable, 1: not editable', + `variables` text DEFAULT NULL COMMENT 'Flow variables, JSON format', PRIMARY KEY (`id`), UNIQUE KEY `uk_uid` (`uid`), KEY `ix_dbgpt_serve_flow_sys_code` (`sys_code`), @@ -295,6 +318,51 @@ CREATE TABLE `dbgpt_serve_flow` ( KEY `ix_dbgpt_serve_flow_name` (`name`) ) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; +-- dbgpt.dbgpt_serve_file definition +CREATE TABLE `dbgpt_serve_file` ( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'Auto increment id', + `bucket` varchar(255) NOT NULL COMMENT 'Bucket name', + `file_id` varchar(255) NOT NULL COMMENT 'File id', + `file_name` varchar(256) NOT NULL COMMENT 'File name', + `file_size` int DEFAULT NULL COMMENT 'File size', + `storage_type` varchar(32) NOT NULL COMMENT 
'Storage type', + `storage_path` varchar(512) NOT NULL COMMENT 'Storage path', + `uri` varchar(512) NOT NULL COMMENT 'File URI', + `custom_metadata` text DEFAULT NULL COMMENT 'Custom metadata, JSON format', + `file_hash` varchar(128) DEFAULT NULL COMMENT 'File hash', + `user_name` varchar(128) DEFAULT NULL COMMENT 'User name', + `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code', + `gmt_created` datetime DEFAULT CURRENT_TIMESTAMP COMMENT 'Record creation time', + `gmt_modified` datetime DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'Record update time', + PRIMARY KEY (`id`), + UNIQUE KEY `uk_bucket_file_id` (`bucket`, `file_id`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; + +-- dbgpt.dbgpt_serve_variables definition +CREATE TABLE `dbgpt_serve_variables` ( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'Auto increment id', + `key` varchar(128) NOT NULL COMMENT 'Variable key', + `name` varchar(128) DEFAULT NULL COMMENT 'Variable name', + `label` varchar(128) DEFAULT NULL COMMENT 'Variable label', + `value` text DEFAULT NULL COMMENT 'Variable value, JSON format', + `value_type` varchar(32) DEFAULT NULL COMMENT 'Variable value type(string, int, float, bool)', + `category` varchar(32) DEFAULT 'common' COMMENT 'Variable category(common or secret)', + `encryption_method` varchar(32) DEFAULT NULL COMMENT 'Variable encryption method(fernet, simple, rsa, aes)', + `salt` varchar(128) DEFAULT NULL COMMENT 'Variable salt', + `scope` varchar(32) DEFAULT 'global' COMMENT 'Variable scope(global,flow,app,agent,datasource,flow_priv,agent_priv, ""etc)', + `scope_key` varchar(256) DEFAULT NULL COMMENT 'Variable scope key, default is empty, for scope is "flow_priv", the scope_key is dag id of flow', + `enabled` int DEFAULT 1 COMMENT 'Variable enabled, 0: disabled, 1: enabled', + `description` text DEFAULT NULL COMMENT 'Variable description', + `user_name` varchar(128) DEFAULT NULL COMMENT 'User name', + `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code', + `gmt_created` datetime DEFAULT CURRENT_TIMESTAMP COMMENT 'Record creation time', + `gmt_modified` datetime DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'Record update time', + PRIMARY KEY (`id`), + KEY `ix_your_table_name_key` (`key`), + KEY `ix_your_table_name_name` (`name`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; + + -- dbgpt.gpts_app definition CREATE TABLE `gpts_app` ( `id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', @@ -309,6 +377,9 @@ CREATE TABLE `gpts_app` ( `created_at` datetime DEFAULT NULL COMMENT 'create time', `updated_at` datetime DEFAULT NULL COMMENT 'last update time', `icon` varchar(1024) DEFAULT NULL COMMENT 'app icon, url', + `published` varchar(64) DEFAULT 'false' COMMENT 'Has it been published?', + `param_need` text DEFAULT NULL COMMENT 'Parameter information supported by the application', + `admins` text DEFAULT NULL COMMENT 'administrator', PRIMARY KEY (`id`), UNIQUE KEY `uk_gpts_app` (`app_name`) ) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; @@ -363,6 +434,38 @@ CREATE TABLE IF NOT EXISTS `dbgpt_cluster_registry_instance` ( UNIQUE KEY `uk_model_instance` (`model_name`, `host`, `port`, `sys_code`) ) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT='Cluster model instance table, for registering and managing model instances'; +-- dbgpt.recommend_question definition +CREATE TABLE `recommend_question` ( + `id` bigint(20) unsigned NOT NULL 
AUTO_INCREMENT COMMENT 'autoincrement id',
+    `gmt_create` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'create time',
+    `gmt_modified` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'last update time',
+    `app_code` varchar(255) DEFAULT NULL COMMENT 'Current AI assistant code',
+    `question` text DEFAULT NULL COMMENT 'question',
+    `user_code` int(11) NOT NULL COMMENT 'user code',
+    `sys_code` varchar(255) NOT NULL COMMENT 'system app code',
+    `valid` varchar(10) DEFAULT 'true' COMMENT 'is it effective,true/false',
+    `chat_mode` varchar(255) DEFAULT NULL COMMENT 'Conversation scene mode,chat_knowledge...',
+    `params` text DEFAULT NULL COMMENT 'question param',
+    `is_hot_question` varchar(10) DEFAULT 'false' COMMENT 'Is it a popular recommendation question?',
+    PRIMARY KEY (`id`),
+    KEY `idx_rec_q_app_code` (`app_code`)
+) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT="AI application related recommendation issues";
+
+-- dbgpt.user_recent_apps definition
+CREATE TABLE `user_recent_apps` (
+    `id` bigint(20) unsigned NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id',
+    `gmt_create` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'create time',
+    `gmt_modified` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'last update time',
+    `app_code` varchar(255) DEFAULT NULL COMMENT 'AI assistant code',
+    `last_accessed` timestamp NULL DEFAULT NULL COMMENT 'User recent usage time',
+    `user_code` varchar(255) DEFAULT NULL COMMENT 'user code',
+    `sys_code` varchar(255) DEFAULT NULL COMMENT 'system app code',
+    PRIMARY KEY (`id`),
+    KEY `idx_user_r_app_code` (`app_code`),
+    KEY `idx_last_accessed` (`last_accessed`),
+    KEY `idx_user_code` (`user_code`)
+) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='User recently used apps';
+
 CREATE DATABASE IF NOT EXISTS EXAMPLE_1;
diff --git a/assets/schema/upgrade/v0_6_0/upgrade_to_v0.6.0.sql b/assets/schema/upgrade/v0_6_0/upgrade_to_v0.6.0.sql
index 3b814a114..e4fa9fee8 100644
--- a/assets/schema/upgrade/v0_6_0/upgrade_to_v0.6.0.sql
+++ b/assets/schema/upgrade/v0_6_0/upgrade_to_v0.6.0.sql
@@ -37,6 +37,8 @@ ALTER TABLE chat_feed_back ADD COLUMN `reason_types` varchar(255) NULL COMMENT
 ALTER TABLE chat_feed_back ADD COLUMN `user_code` varchar(128) NULL COMMENT 'User code';
 ALTER TABLE chat_feed_back ADD COLUMN `remark` text NULL COMMENT 'Feedback remark';
 
+-- dbgpt_serve_flow
+ALTER TABLE dbgpt_serve_flow ADD COLUMN `variables` text DEFAULT NULL COMMENT 'Flow variables, JSON format';
 
 -- dbgpt.recommend_question definition
 CREATE TABLE `recommend_question` (
@@ -52,7 +54,7 @@ CREATE TABLE `recommend_question` (
     `params` text DEFAULT NULL COMMENT 'question param',
     `is_hot_question` varchar(10) DEFAULT 'false' COMMENT 'Is it a popular recommendation question?',
     PRIMARY KEY (`id`),
-    KEY `idx_app_code` (`app_code`)
+    KEY `idx_rec_q_app_code` (`app_code`)
 ) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT="AI application related recommendation issues";
 
 -- dbgpt.user_recent_apps definition
@@ -65,7 +67,52 @@ CREATE TABLE `user_recent_apps` (
     `user_code` varchar(255) DEFAULT NULL COMMENT 'user code',
     `sys_code` varchar(255) DEFAULT NULL COMMENT 'system app code',
     PRIMARY KEY (`id`),
-    KEY `idx_app_code` (`app_code`),
+    KEY `idx_user_r_app_code` (`app_code`),
     KEY `idx_last_accessed` (`last_accessed`),
     KEY `idx_user_code` (`user_code`)
-) ENGINE=InnoDB AUTO_INCREMENT=1
DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='User recently used apps' \ No newline at end of file +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='User recently used apps'; + +-- dbgpt.dbgpt_serve_file definition +CREATE TABLE `dbgpt_serve_file` ( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'Auto increment id', + `bucket` varchar(255) NOT NULL COMMENT 'Bucket name', + `file_id` varchar(255) NOT NULL COMMENT 'File id', + `file_name` varchar(256) NOT NULL COMMENT 'File name', + `file_size` int DEFAULT NULL COMMENT 'File size', + `storage_type` varchar(32) NOT NULL COMMENT 'Storage type', + `storage_path` varchar(512) NOT NULL COMMENT 'Storage path', + `uri` varchar(512) NOT NULL COMMENT 'File URI', + `custom_metadata` text DEFAULT NULL COMMENT 'Custom metadata, JSON format', + `file_hash` varchar(128) DEFAULT NULL COMMENT 'File hash', + `user_name` varchar(128) DEFAULT NULL COMMENT 'User name', + `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code', + `gmt_created` datetime DEFAULT CURRENT_TIMESTAMP COMMENT 'Record creation time', + `gmt_modified` datetime DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'Record update time', + PRIMARY KEY (`id`), + UNIQUE KEY `uk_bucket_file_id` (`bucket`, `file_id`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; + +-- dbgpt.dbgpt_serve_variables definition +CREATE TABLE `dbgpt_serve_variables` ( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'Auto increment id', + `key` varchar(128) NOT NULL COMMENT 'Variable key', + `name` varchar(128) DEFAULT NULL COMMENT 'Variable name', + `label` varchar(128) DEFAULT NULL COMMENT 'Variable label', + `value` text DEFAULT NULL COMMENT 'Variable value, JSON format', + `value_type` varchar(32) DEFAULT NULL COMMENT 'Variable value type(string, int, float, bool)', + `category` varchar(32) DEFAULT 'common' COMMENT 'Variable category(common or secret)', + `encryption_method` varchar(32) DEFAULT NULL COMMENT 'Variable encryption method(fernet, simple, rsa, aes)', + `salt` varchar(128) DEFAULT NULL COMMENT 'Variable salt', + `scope` varchar(32) DEFAULT 'global' COMMENT 'Variable scope(global,flow,app,agent,datasource,flow_priv,agent_priv, ""etc)', + `scope_key` varchar(256) DEFAULT NULL COMMENT 'Variable scope key, default is empty, for scope is "flow_priv", the scope_key is dag id of flow', + `enabled` int DEFAULT 1 COMMENT 'Variable enabled, 0: disabled, 1: enabled', + `description` text DEFAULT NULL COMMENT 'Variable description', + `user_name` varchar(128) DEFAULT NULL COMMENT 'User name', + `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code', + `gmt_created` datetime DEFAULT CURRENT_TIMESTAMP COMMENT 'Record creation time', + `gmt_modified` datetime DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'Record update time', + PRIMARY KEY (`id`), + KEY `ix_your_table_name_key` (`key`), + KEY `ix_your_table_name_name` (`name`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; + diff --git a/assets/schema/upgrade/v0_6_0/v0.6.0.sql b/assets/schema/upgrade/v0_6_0/v0.5.10.sql similarity index 83% rename from assets/schema/upgrade/v0_6_0/v0.6.0.sql rename to assets/schema/upgrade/v0_6_0/v0.5.10.sql index 78a7d88ec..a70d8e643 100644 --- a/assets/schema/upgrade/v0_6_0/v0.6.0.sql +++ b/assets/schema/upgrade/v0_6_0/v0.5.10.sql @@ -1,4 +1,4 @@ --- Full SQL of v0.5.9, please not modify this file(It must be same as the file in the release package) +-- Full SQL of v0.5.10, 
please do not modify this file (it must be the same as the file in the release package)
 
 CREATE DATABASE IF NOT EXISTS dbgpt;
@@ -16,11 +16,10 @@ CREATE TABLE IF NOT EXISTS `knowledge_space`
     `id` int NOT NULL AUTO_INCREMENT COMMENT 'auto increment id',
     `name` varchar(100) NOT NULL COMMENT 'knowledge space name',
     `vector_type` varchar(50) NOT NULL COMMENT 'vector type',
+    `domain_type` varchar(50) NOT NULL COMMENT 'domain type',
     `desc` varchar(500) NOT NULL COMMENT 'description',
     `owner` varchar(100) DEFAULT NULL COMMENT 'owner',
     `context` TEXT DEFAULT NULL COMMENT 'context argument',
-    `user_id` varchar(255) DEFAULT NULL COMMENT 'knowledge space owner',
-    `user_ids` TEXT DEFAULT NULL COMMENT 'knowledge space members',
     `gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
     `gmt_modified` TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time',
     PRIMARY KEY (`id`),
@@ -32,14 +31,12 @@ CREATE TABLE IF NOT EXISTS `knowledge_document`
     `id` int NOT NULL AUTO_INCREMENT COMMENT 'auto increment id',
     `doc_name` varchar(100) NOT NULL COMMENT 'document path name',
     `doc_type` varchar(50) NOT NULL COMMENT 'doc type',
-    `doc_token` varchar(100) NOT NULL COMMENT 'doc token',
     `space` varchar(50) NOT NULL COMMENT 'knowledge space',
     `chunk_size` int NOT NULL COMMENT 'chunk size',
     `last_sync` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'last sync time',
     `status` varchar(50) NOT NULL COMMENT 'status TODO,RUNNING,FAILED,FINISHED',
     `content` LONGTEXT NOT NULL COMMENT 'knowledge embedding sync result',
     `result` TEXT NULL COMMENT 'knowledge content',
-    `questions` TEXT NULL COMMENT 'document related questions',
     `vector_ids` LONGTEXT NULL COMMENT 'vector_ids',
     `summary` LONGTEXT NULL COMMENT 'knowledge summary',
     `gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
@@ -55,7 +52,6 @@ CREATE TABLE IF NOT EXISTS `document_chunk`
     `doc_type` varchar(50) NOT NULL COMMENT 'doc type',
     `document_id` int NOT NULL COMMENT 'document parent id',
     `content` longtext NOT NULL COMMENT 'chunk content',
-    `questions` text NULL COMMENT 'chunk related questions',
     `meta_info` varchar(200) NOT NULL COMMENT 'metadata info',
     `gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
     `gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time',
@@ -77,9 +73,6 @@ CREATE TABLE IF NOT EXISTS `connect_config`
     `db_pwd` varchar(255) DEFAULT NULL COMMENT 'db password',
     `comment` text COMMENT 'db comment',
     `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code',
-    `user_name` varchar(255) DEFAULT NULL COMMENT 'user name',
-    `user_id` varchar(255) DEFAULT NULL COMMENT 'user id',
-
     PRIMARY KEY (`id`),
     UNIQUE KEY `uk_db` (`db_name`),
     KEY `idx_q_db_type` (`db_type`)
@@ -94,7 +87,6 @@ CREATE TABLE IF NOT EXISTS `chat_history`
     `user_name` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'interlocutor',
     `messages` text COLLATE utf8mb4_unicode_ci COMMENT 'Conversation details',
     `message_ids` text COLLATE utf8mb4_unicode_ci COMMENT 'Message id list, split by comma',
-    `app_code` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'App unique code',
     `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code',
     `gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
     `gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time',
@@ -125,11 +117,6 @@ CREATE TABLE IF NOT EXISTS `chat_feed_back`
     `question` longtext DEFAULT NULL COMMENT 'User question',
`knowledge_space` varchar(128) DEFAULT NULL COMMENT 'Knowledge space name', `messages` longtext DEFAULT NULL COMMENT 'The details of user feedback', - `message_id` varchar(255) NULL COMMENT 'Message id', - `feedback_type` varchar(50) NULL COMMENT 'Feedback type like or unlike', - `reason_types` varchar(255) NULL COMMENT 'Feedback reason categories', - `remark` text NULL COMMENT 'Feedback remark', - `user_code` varchar(128) NULL COMMENT 'User code', `user_name` varchar(128) DEFAULT NULL COMMENT 'User name', `gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time', `gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time', @@ -179,20 +166,17 @@ CREATE TABLE IF NOT EXISTS `plugin_hub` CREATE TABLE IF NOT EXISTS `prompt_manage` ( `id` int(11) NOT NULL AUTO_INCREMENT, - `chat_scene` varchar(100) DEFAULT NULL COMMENT 'Chat scene', - `sub_chat_scene` varchar(100) DEFAULT NULL COMMENT 'Sub chat scene', - `prompt_type` varchar(100) DEFAULT NULL COMMENT 'Prompt type: common or private', - `prompt_name` varchar(256) DEFAULT NULL COMMENT 'prompt name', - `prompt_code` varchar(256) DEFAULT NULL COMMENT 'prompt code', - `content` longtext COMMENT 'Prompt content', - `input_variables` varchar(1024) DEFAULT NULL COMMENT 'Prompt input variables(split by comma))', - `response_schema` text DEFAULT NULL COMMENT 'Prompt response schema', - `model` varchar(128) DEFAULT NULL COMMENT 'Prompt model name(we can use different models for different prompt)', - `prompt_language` varchar(32) DEFAULT NULL COMMENT 'Prompt language(eg:en, zh-cn)', - `prompt_format` varchar(32) DEFAULT 'f-string' COMMENT 'Prompt format(eg: f-string, jinja2)', - `prompt_desc` varchar(512) DEFAULT NULL COMMENT 'Prompt description', - `user_code` varchar(128) DEFAULT NULL COMMENT 'User code', - `user_name` varchar(128) DEFAULT NULL COMMENT 'User name', + `chat_scene` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Chat scene', + `sub_chat_scene` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Sub chat scene', + `prompt_type` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt type: common or private', + `prompt_name` varchar(256) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'prompt name', + `content` longtext COLLATE utf8mb4_unicode_ci COMMENT 'Prompt content', + `input_variables` varchar(1024) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt input variables(split by comma))', + `model` varchar(128) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt model name(we can use different models for different prompt)', + `prompt_language` varchar(32) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt language(eg:en, zh-cn)', + `prompt_format` varchar(32) COLLATE utf8mb4_unicode_ci DEFAULT 'f-string' COMMENT 'Prompt format(eg: f-string, jinja2)', + `prompt_desc` varchar(512) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt description', + `user_name` varchar(128) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'User name', `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code', `gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time', `gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time', @@ -247,15 +231,11 @@ CREATE TABLE `gpts_messages` ( `receiver` varchar(255) NOT NULL COMMENT 'Who receive message in the current conversation turn', `model_name` varchar(255) DEFAULT NULL COMMENT 'message generate model', `rounds` int(11) NOT NULL 
COMMENT 'dialogue turns', - `is_success` int(4) NULL DEFAULT 0 COMMENT 'agent message is success', - `app_code` varchar(255) NOT NULL COMMENT 'Current AI assistant code', - `app_name` varchar(255) NOT NULL COMMENT 'Current AI assistant name', `content` text COMMENT 'Content of the speech', `current_goal` text COMMENT 'The target corresponding to the current message', `context` text COMMENT 'Current conversation context', `review_info` text COMMENT 'Current conversation review info', `action_report` text COMMENT 'Current conversation action report', - `resource_info` text DEFAULT NULL COMMENT 'Current conversation resource info', `role` varchar(255) DEFAULT NULL COMMENT 'The role of the current message content', `created_at` datetime DEFAULT NULL COMMENT 'create time', `updated_at` datetime DEFAULT NULL COMMENT 'last update time', @@ -328,9 +308,6 @@ CREATE TABLE `gpts_app` ( `created_at` datetime DEFAULT NULL COMMENT 'create time', `updated_at` datetime DEFAULT NULL COMMENT 'last update time', `icon` varchar(1024) DEFAULT NULL COMMENT 'app icon, url', - `published` varchar(64) DEFAULT 'false' COMMENT 'Has it been published?', - `param_need` text DEFAULT NULL COMMENT 'Parameter information supported by the application', - `admins` text DEFAULT NULL COMMENT 'administrator', PRIMARY KEY (`id`), UNIQUE KEY `uk_gpts_app` (`app_name`) ) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; @@ -364,37 +341,26 @@ CREATE TABLE `gpts_app_detail` ( UNIQUE KEY `uk_gpts_app_agent_node` (`app_name`,`agent_name`,`node_id`) ) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; --- dbgpt.recommend_question definition -CREATE TABLE `recommend_question` ( - `id` bigint(20) unsigned NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', - `gmt_create` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'create time', - `gmt_modified` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'last update time', - `app_code` varchar(255) DEFAULT NULL COMMENT 'Current AI assistant code', - `question` text DEFAULT NULL COMMENT 'question', - `user_code` int(11) NOT NULL COMMENT 'user code', - `sys_code` varchar(255) NOT NULL COMMENT 'system app code', - `valid` varchar(10) DEFAULT 'true' COMMENT 'is it effective,true/false', - `chat_mode` varchar(255) DEFAULT NULL COMMENT 'Conversation scene mode,chat_knowledge...', - `params` text DEFAULT NULL COMMENT 'question param', - `is_hot_question` varchar(10) DEFAULT 'false' COMMENT 'Is it a popular recommendation question?', - PRIMARY KEY (`id`), - KEY `idx_app_code` (`app_code`) -) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT="AI application related recommendation issues"; ---dbgpt.user_recent_apps definition -CREATE TABLE `user_recent_apps` ( - `id` bigint(20) unsigned NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', - `gmt_create` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'create time', - `gmt_modified` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'last update time', - `app_code` varchar(255) DEFAULT NULL COMMENT 'AI assistant code', - `last_accessed` timestamp NULL DEFAULT NULL COMMENT 'User recent usage time', - `user_code` varchar(255) DEFAULT NULL COMMENT 'user code', - `sys_code` varchar(255) DEFAULT NULL COMMENT 'system app code', +-- For deploy model cluster of DB-GPT(StorageModelRegistry) +CREATE TABLE IF NOT EXISTS `dbgpt_cluster_registry_instance` ( + `id` int(11) NOT NULL 
AUTO_INCREMENT COMMENT 'Auto increment id', + `model_name` varchar(128) NOT NULL COMMENT 'Model name', + `host` varchar(128) NOT NULL COMMENT 'Host of the model', + `port` int(11) NOT NULL COMMENT 'Port of the model', + `weight` float DEFAULT 1.0 COMMENT 'Weight of the model', + `check_healthy` tinyint(1) DEFAULT 1 COMMENT 'Whether to check the health of the model', + `healthy` tinyint(1) DEFAULT 0 COMMENT 'Whether the model is healthy', + `enabled` tinyint(1) DEFAULT 1 COMMENT 'Whether the model is enabled', + `prompt_template` varchar(128) DEFAULT NULL COMMENT 'Prompt template for the model instance', + `last_heartbeat` datetime DEFAULT NULL COMMENT 'Last heartbeat time of the model instance', + `user_name` varchar(128) DEFAULT NULL COMMENT 'User name', + `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code', + `gmt_created` datetime DEFAULT CURRENT_TIMESTAMP COMMENT 'Record creation time', + `gmt_modified` datetime DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'Record update time', PRIMARY KEY (`id`), - KEY `idx_app_code` (`app_code`), - KEY `idx_last_accessed` (`last_accessed`), - KEY `idx_user_code` (`user_code`) -) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='User recently used apps' + UNIQUE KEY `uk_model_instance` (`model_name`, `host`, `port`, `sys_code`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT='Cluster model instance table, for registering and managing model instances'; CREATE diff --git a/dbgpt/_private/config.py b/dbgpt/_private/config.py index 3fdc927c9..cd7936a60 100644 --- a/dbgpt/_private/config.py +++ b/dbgpt/_private/config.py @@ -303,6 +303,7 @@ class Config(metaclass=Singleton): ) # global dbgpt api key self.API_KEYS = os.getenv("API_KEYS", None) + self.ENCRYPT_KEY = os.getenv("ENCRYPT_KEY", "your_secret_key") # Non-streaming scene retries self.DBGPT_APP_SCENE_NON_STREAMING_RETRIES_BASE = int( @@ -320,6 +321,17 @@ class Config(metaclass=Singleton): os.getenv("USE_NEW_WEB_UI", "True").lower() == "true" ) + # file server configuration + # The host of the current file server, if None, get the host automatically + self.FILE_SERVER_HOST = os.getenv("FILE_SERVER_HOST") + self.FILE_SERVER_LOCAL_STORAGE_PATH = os.getenv( + "FILE_SERVER_LOCAL_STORAGE_PATH" + ) + # multi-instance flag + self.WEBSERVER_MULTI_INSTANCE = ( + os.getenv("MULTI_INSTANCE", "False").lower() == "true" + ) + @property def local_db_manager(self) -> "ConnectorManager": from dbgpt.datasource.manages import ConnectorManager diff --git a/dbgpt/app/component_configs.py b/dbgpt/app/component_configs.py index 3ef08d4bc..a8a0f24d1 100644 --- a/dbgpt/app/component_configs.py +++ b/dbgpt/app/component_configs.py @@ -52,17 +52,17 @@ def initialize_components( param, system_app, embedding_model_name, embedding_model_path ) _initialize_rerank_model(param, system_app, rerank_model_name, rerank_model_path) - _initialize_model_cache(system_app) + _initialize_model_cache(system_app, param.port) _initialize_awel(system_app, param) # Initialize resource manager of agent _initialize_resource_manager(system_app) _initialize_agent(system_app) _initialize_openapi(system_app) # Register serve apps - register_serve_apps(system_app, CFG) + register_serve_apps(system_app, CFG, param.port) -def _initialize_model_cache(system_app: SystemApp): +def _initialize_model_cache(system_app: SystemApp, port: int): from dbgpt.storage.cache import initialize_cache if not CFG.MODEL_CACHE_ENABLE: @@ -72,6 +72,8 @@ def _initialize_model_cache(system_app: 
SystemApp): storage_type = CFG.MODEL_CACHE_STORAGE_TYPE or "disk" max_memory_mb = CFG.MODEL_CACHE_MAX_MEMORY_MB or 256 persist_dir = CFG.MODEL_CACHE_STORAGE_DISK_DIR or MODEL_DISK_CACHE_DIR + if CFG.WEBSERVER_MULTI_INSTANCE: + persist_dir = f"{persist_dir}_{port}" initialize_cache(system_app, storage_type, max_memory_mb, persist_dir) diff --git a/dbgpt/app/initialization/db_model_initialization.py b/dbgpt/app/initialization/db_model_initialization.py index e3f414cb0..911da686a 100644 --- a/dbgpt/app/initialization/db_model_initialization.py +++ b/dbgpt/app/initialization/db_model_initialization.py @@ -1,5 +1,6 @@ """Import all models to make sure they are registered with SQLAlchemy. """ + from dbgpt.app.knowledge.chunk_db import DocumentChunkEntity from dbgpt.app.knowledge.document_db import KnowledgeDocumentEntity from dbgpt.app.openapi.api_v1.feedback.feed_back_db import ChatFeedBackEntity @@ -10,7 +11,9 @@ from dbgpt.serve.agent.app.recommend_question.recommend_question import ( ) from dbgpt.serve.agent.hub.db.my_plugin_db import MyPluginEntity from dbgpt.serve.agent.hub.db.plugin_hub_db import PluginHubEntity +from dbgpt.serve.file.models.models import ServeEntity as FileServeEntity from dbgpt.serve.flow.models.models import ServeEntity as FlowServeEntity +from dbgpt.serve.flow.models.models import VariablesEntity as FlowVariableEntity from dbgpt.serve.prompt.models.models import ServeEntity as PromptManageEntity from dbgpt.serve.rag.models.models import KnowledgeSpaceEntity from dbgpt.storage.chat_history.chat_history_db import ( @@ -20,6 +23,7 @@ from dbgpt.storage.chat_history.chat_history_db import ( _MODELS = [ PluginHubEntity, + FileServeEntity, MyPluginEntity, PromptManageEntity, KnowledgeSpaceEntity, @@ -32,4 +36,5 @@ _MODELS = [ ModelInstanceEntity, FlowServeEntity, RecommendQuestionEntity, + FlowVariableEntity, ] diff --git a/dbgpt/app/initialization/serve_initialization.py b/dbgpt/app/initialization/serve_initialization.py index 65620bff9..b88d1d215 100644 --- a/dbgpt/app/initialization/serve_initialization.py +++ b/dbgpt/app/initialization/serve_initialization.py @@ -2,11 +2,13 @@ from dbgpt._private.config import Config from dbgpt.component import SystemApp -def register_serve_apps(system_app: SystemApp, cfg: Config): +def register_serve_apps(system_app: SystemApp, cfg: Config, webserver_port: int): """Register serve apps""" system_app.config.set("dbgpt.app.global.language", cfg.LANGUAGE) if cfg.API_KEYS: system_app.config.set("dbgpt.app.global.api_keys", cfg.API_KEYS) + if cfg.ENCRYPT_KEY: + system_app.config.set("dbgpt.app.global.encrypt_key", cfg.ENCRYPT_KEY) # ################################ Prompt Serve Register Begin ###################################### from dbgpt.serve.prompt.serve import ( @@ -45,6 +47,8 @@ def register_serve_apps(system_app: SystemApp, cfg: Config): # Register serve app system_app.register(FlowServe) + # ################################ AWEL Flow Serve Register End ######################################## + # ################################ Rag Serve Register Begin ###################################### from dbgpt.serve.rag.serve import ( @@ -55,6 +59,8 @@ def register_serve_apps(system_app: SystemApp, cfg: Config): # Register serve app system_app.register(RagServe) + # ################################ Rag Serve Register End ######################################## + # ################################ Datasource Serve Register Begin ###################################### from dbgpt.serve.datasource.serve import ( @@ -64,7 +70,8 @@ def 
register_serve_apps(system_app: SystemApp, cfg: Config): # Register serve app system_app.register(DatasourceServe) - # ################################ AWEL Flow Serve Register End ######################################## + + # ################################ Datasource Serve Register End ######################################## # ################################ Chat Feedback Serve Register End ######################################## from dbgpt.serve.feedback.serve import ( @@ -75,3 +82,32 @@ def register_serve_apps(system_app: SystemApp, cfg: Config): # Register serve feedback system_app.register(FeedbackServe) # ################################ Chat Feedback Register End ######################################## + + # ################################ File Serve Register Begin ###################################### + + from dbgpt.configs.model_config import FILE_SERVER_LOCAL_STORAGE_PATH + from dbgpt.serve.file.serve import ( + SERVE_CONFIG_KEY_PREFIX as FILE_SERVE_CONFIG_KEY_PREFIX, + ) + from dbgpt.serve.file.serve import Serve as FileServe + + local_storage_path = ( + cfg.FILE_SERVER_LOCAL_STORAGE_PATH or FILE_SERVER_LOCAL_STORAGE_PATH + ) + if cfg.WEBSERVER_MULTI_INSTANCE: + local_storage_path = f"{local_storage_path}_{webserver_port}" + # Set config + system_app.config.set( + f"{FILE_SERVE_CONFIG_KEY_PREFIX}local_storage_path", local_storage_path + ) + system_app.config.set( + f"{FILE_SERVE_CONFIG_KEY_PREFIX}file_server_port", webserver_port + ) + if cfg.FILE_SERVER_HOST: + system_app.config.set( + f"{FILE_SERVE_CONFIG_KEY_PREFIX}file_server_host", cfg.FILE_SERVER_HOST + ) + # Register serve app + system_app.register(FileServe) + + # ################################ File Serve Register End ######################################## diff --git a/dbgpt/component.py b/dbgpt/component.py index bb7a7a9e4..da3c5e753 100644 --- a/dbgpt/component.py +++ b/dbgpt/component.py @@ -89,6 +89,8 @@ class ComponentType(str, Enum): CONNECTOR_MANAGER = "dbgpt_connector_manager" AGENT_MANAGER = "dbgpt_agent_manager" RESOURCE_MANAGER = "dbgpt_resource_manager" + VARIABLES_PROVIDER = "dbgpt_variables_provider" + FILE_STORAGE_CLIENT = "dbgpt_file_storage_client" _EMPTY_DEFAULT_COMPONENT = "_EMPTY_DEFAULT_COMPONENT" diff --git a/dbgpt/configs/__init__.py b/dbgpt/configs/__init__.py index 66177d8ee..fb0098cd4 100644 --- a/dbgpt/configs/__init__.py +++ b/dbgpt/configs/__init__.py @@ -20,3 +20,11 @@ del load_dotenv TAG_KEY_KNOWLEDGE_FACTORY_DOMAIN_TYPE = "knowledge_factory_domain_type" TAG_KEY_KNOWLEDGE_CHAT_DOMAIN_TYPE = "knowledge_chat_domain_type" DOMAIN_TYPE_FINANCIAL_REPORT = "FinancialReport" + +VARIABLES_SCOPE_GLOBAL = "global" +VARIABLES_SCOPE_APP = "app" +VARIABLES_SCOPE_AGENT = "agent" +VARIABLES_SCOPE_FLOW = "flow" +VARIABLES_SCOPE_DATASOURCE = "datasource" +VARIABLES_SCOPE_FLOW_PRIVATE = "flow_priv" +VARIABLES_SCOPE_AGENT_PRIVATE = "agent_priv" diff --git a/dbgpt/configs/model_config.py b/dbgpt/configs/model_config.py index 4d02a2730..e4abac3e7 100644 --- a/dbgpt/configs/model_config.py +++ b/dbgpt/configs/model_config.py @@ -14,6 +14,7 @@ DATASETS_DIR = os.path.join(PILOT_PATH, "datasets") DATA_DIR = os.path.join(PILOT_PATH, "data") PLUGINS_DIR = os.path.join(ROOT_PATH, "plugins") MODEL_DISK_CACHE_DIR = os.path.join(DATA_DIR, "model_cache") +FILE_SERVER_LOCAL_STORAGE_PATH = os.path.join(DATA_DIR, "file_server") _DAG_DEFINITION_DIR = os.path.join(ROOT_PATH, "examples/awel") # Global language setting LOCALES_DIR = os.path.join(ROOT_PATH, "i18n/locales") diff --git a/dbgpt/core/awel/dag/base.py 
b/dbgpt/core/awel/dag/base.py index ddcfd52bc..ffe6a7b0e 100644 --- a/dbgpt/core/awel/dag/base.py +++ b/dbgpt/core/awel/dag/base.py @@ -5,13 +5,26 @@ DAG is the core component of AWEL, it is used to define the relationship between import asyncio import contextvars +import dataclasses import logging import threading import uuid from abc import ABC, abstractmethod from collections import deque from concurrent.futures import Executor -from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Union, cast +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Literal, + Optional, + Sequence, + Set, + Union, + cast, +) from dbgpt.component import SystemApp @@ -23,6 +36,9 @@ logger = logging.getLogger(__name__) DependencyType = Union["DependencyMixin", Sequence["DependencyMixin"]] +if TYPE_CHECKING: + from ...interface.variables import VariablesProvider + def _is_async_context(): try: @@ -128,6 +144,8 @@ class DAGVar: # The executor for current DAG, this is used run some sync tasks in async DAG _executor: Optional[Executor] = None + _variables_provider: Optional["VariablesProvider"] = None + @classmethod def enter_dag(cls, dag) -> None: """Enter a DAG context. @@ -221,6 +239,24 @@ class DAGVar: """ cls._executor = executor + @classmethod + def get_variables_provider(cls) -> Optional["VariablesProvider"]: + """Get the current variables provider. + + Returns: + Optional[VariablesProvider]: The current variables provider + """ + return cls._variables_provider + + @classmethod + def set_variables_provider(cls, variables_provider: "VariablesProvider") -> None: + """Set the current variables provider. + + Args: + variables_provider (VariablesProvider): The variables provider to set + """ + cls._variables_provider = variables_provider + class DAGLifecycle: """The lifecycle of DAG.""" @@ -455,6 +491,100 @@ def _build_task_key(task_name: str, key: str) -> str: return f"{task_name}___$$$$$$___{key}" +@dataclasses.dataclass +class _DAGVariablesItem: + """The DAG variables item. + + It is a private class, just used for internal. + """ + + key: str + name: str + label: str + value: Any + category: Literal["common", "secret"] = "common" + scope: str = "global" + value_type: Optional[str] = None + scope_key: Optional[str] = None + sys_code: Optional[str] = None + user_name: Optional[str] = None + description: Optional[str] = None + + +@dataclasses.dataclass +class DAGVariables: + """The DAG variables.""" + + items: List[_DAGVariablesItem] = dataclasses.field(default_factory=list) + _cached_provider: Optional["VariablesProvider"] = None + _lock: threading.Lock = dataclasses.field(default_factory=threading.Lock) + + def merge(self, dag_variables: "DAGVariables") -> "DAGVariables": + """Merge the DAG variables. 
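+
+        Items from ``self`` take precedence: an item from ``dag_variables`` is
+        appended only when no existing item shares the same identity
+        (key, name, scope, scope_key, sys_code and user_name).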
+ + Args: + dag_variables (DAGVariables): The DAG variables to merge + """ + + def _build_key(item: _DAGVariablesItem): + key = "_".join([item.key, item.name, item.scope]) + if item.scope_key: + key += f"_{item.scope_key}" + if item.sys_code: + key += f"_{item.sys_code}" + if item.user_name: + key += f"_{item.user_name}" + return key + + new_items = [] + exist_vars = set() + for item in self.items: + new_items.append(item) + exist_vars.add(_build_key(item)) + for item in dag_variables.items: + key = _build_key(item) + if key not in exist_vars: + new_items.append(item) + return DAGVariables( + items=new_items, + _cached_provider=self._cached_provider or dag_variables._cached_provider, + ) + + def to_provider(self) -> "VariablesProvider": + """Convert the DAG variables to variables provider. + + Returns: + VariablesProvider: The variables provider + """ + if not self._cached_provider: + from ...interface.variables import ( + StorageVariables, + StorageVariablesProvider, + ) + + with self._lock: + # Create a new provider safely + provider = StorageVariablesProvider() + for item in self.items: + storage_vars = StorageVariables( + key=item.key, + name=item.name, + label=item.label, + value=item.value, + category=item.category, + scope=item.scope, + value_type=item.value_type, + scope_key=item.scope_key, + sys_code=item.sys_code, + user_name=item.user_name, + description=item.description, + ) + provider.save(storage_vars) + self._cached_provider = provider + + return self._cached_provider + + class DAGContext: """The context of current DAG, created when the DAG is running. @@ -468,6 +598,7 @@ class DAGContext: event_loop_task_id: int, streaming_call: bool = False, node_name_to_ids: Optional[Dict[str, str]] = None, + dag_variables: Optional[DAGVariables] = None, ) -> None: """Initialize a DAGContext. @@ -477,6 +608,7 @@ class DAGContext: streaming_call (bool, optional): Whether the current DAG is streaming call. Defaults to False. node_name_to_ids (Optional[Dict[str, str]], optional): The node name to node + dag_variables (Optional[DAGVariables], optional): The DAG variables. 
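+                If provided, it can be converted to a ``VariablesProvider``
+                via ``DAGVariables.to_provider()``, so that variable
+                placeholders can be resolved while the DAG runs.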
""" if not node_name_to_ids: node_name_to_ids = {} @@ -486,6 +618,7 @@ class DAGContext: self._node_to_outputs: Dict[str, TaskContext] = node_to_outputs self._node_name_to_ids: Dict[str, str] = node_name_to_ids self._event_loop_task_id = event_loop_task_id + self._dag_variables = dag_variables @property def _task_outputs(self) -> Dict[str, TaskContext]: @@ -619,6 +752,7 @@ class DAG: resource_group: Optional[ResourceGroup] = None, tags: Optional[Dict[str, str]] = None, description: Optional[str] = None, + default_dag_variables: Optional[DAGVariables] = None, ) -> None: """Initialize a DAG.""" self._dag_id = dag_id @@ -632,6 +766,7 @@ class DAG: self._resource_group: Optional[ResourceGroup] = resource_group self._lock = asyncio.Lock() self._event_loop_task_id_to_ctx: Dict[int, DAGContext] = {} + self._default_dag_variables = default_dag_variables def _append_node(self, node: DAGNode) -> None: if node.node_id in self.node_map: diff --git a/dbgpt/core/awel/dag/dag_manager.py b/dbgpt/core/awel/dag/dag_manager.py index 91a49a166..15a07254a 100644 --- a/dbgpt/core/awel/dag/dag_manager.py +++ b/dbgpt/core/awel/dag/dag_manager.py @@ -197,7 +197,7 @@ class DAGManager(BaseComponent): return self._dag_metadata_map.get(dag.dag_id) -def _parse_metadata(dag: DAG): +def _parse_metadata(dag: DAG) -> DAGMetadata: from ..util.chat_util import _is_sse_output metadata = DAGMetadata() diff --git a/dbgpt/core/awel/flow/__init__.py b/dbgpt/core/awel/flow/__init__.py index 5a173565f..80db5b7e6 100644 --- a/dbgpt/core/awel/flow/__init__.py +++ b/dbgpt/core/awel/flow/__init__.py @@ -7,6 +7,7 @@ from ..util.parameter_util import ( # noqa: F401 BaseDynamicOptions, FunctionDynamicOptions, OptionValue, + VariablesDynamicOptions, ) from .base import ( # noqa: F401 IOField, @@ -35,4 +36,5 @@ __ALL__ = [ "IOField", "BaseDynamicOptions", "FunctionDynamicOptions", + "VariablesDynamicOptions", ] diff --git a/dbgpt/core/awel/flow/base.py b/dbgpt/core/awel/flow/base.py index cd8d9f78e..314cb2171 100644 --- a/dbgpt/core/awel/flow/base.py +++ b/dbgpt/core/awel/flow/base.py @@ -6,7 +6,7 @@ import inspect from abc import ABC from datetime import date, datetime from enum import Enum -from typing import Any, Dict, List, Optional, Type, TypeVar, Union, cast +from typing import Any, Dict, List, Literal, Optional, Type, TypeVar, Union, cast from dbgpt._private.pydantic import ( BaseModel, @@ -15,10 +15,17 @@ from dbgpt._private.pydantic import ( model_to_dict, model_validator, ) -from dbgpt.core.awel.util.parameter_util import BaseDynamicOptions, OptionValue +from dbgpt.component import SystemApp +from dbgpt.core.awel.util.parameter_util import ( + BaseDynamicOptions, + OptionValue, + RefreshOptionRequest, +) from dbgpt.core.interface.serialization import Serializable +from dbgpt.util.executor_utils import DefaultExecutorFactory, blocking_func_to_async from .exceptions import FlowMetadataException, FlowParameterMetadataException +from .ui import UIComponent _TYPE_REGISTRY: Dict[str, Type] = {} @@ -136,6 +143,7 @@ _OPERATOR_CATEGORY_DETAIL = { "agent": _CategoryDetail("Agent", "The agent operator"), "rag": _CategoryDetail("RAG", "The RAG operator"), "experimental": _CategoryDetail("EXPERIMENTAL", "EXPERIMENTAL operator"), + "example": _CategoryDetail("Example", "Example operator"), } @@ -151,6 +159,7 @@ class OperatorCategory(str, Enum): AGENT = "agent" RAG = "rag" EXPERIMENTAL = "experimental" + EXAMPLE = "example" def label(self) -> str: """Get the label of the category.""" @@ -193,6 +202,7 @@ _RESOURCE_CATEGORY_DETAIL = { 
"embeddings": _CategoryDetail("Embeddings", "The embeddings resource"), "rag": _CategoryDetail("RAG", "The resource"), "vector_store": _CategoryDetail("Vector Store", "The vector store resource"), + "example": _CategoryDetail("Example", "The example resource"), } @@ -209,6 +219,7 @@ class ResourceCategory(str, Enum): EMBEDDINGS = "embeddings" RAG = "rag" VECTOR_STORE = "vector_store" + EXAMPLE = "example" def label(self) -> str: """Get the label of the category.""" @@ -343,6 +354,9 @@ class Parameter(TypeMetadata, Serializable): alias: Optional[List[str]] = Field( None, description="The alias of the parameter(Compatible with old version)" ) + ui: Optional[UIComponent] = Field( + None, description="The UI component of the parameter" + ) @model_validator(mode="before") @classmethod @@ -366,27 +380,40 @@ class Parameter(TypeMetadata, Serializable): return values @classmethod - def _covert_to_real_type(cls, type_cls: str, v: Any): + def _covert_to_real_type(cls, type_cls: str, v: Any) -> Any: if type_cls and v is not None: + typed_value: Any = v try: # Try to convert the value to the type. if type_cls == "builtins.str": - return str(v) + typed_value = str(v) elif type_cls == "builtins.int": - return int(v) + typed_value = int(v) elif type_cls == "builtins.float": - return float(v) + typed_value = float(v) elif type_cls == "builtins.bool": if str(v).lower() in ["false", "0", "", "no", "off"]: return False - return bool(v) + typed_value = bool(v) + return typed_value except ValueError: raise ValidationError(f"Value '{v}' is not valid for type {type_cls}") return v def get_typed_value(self) -> Any: - """Get the typed value.""" - return self._covert_to_real_type(self.type_cls, self.value) + """Get the typed value. + + Returns: + Any: The typed value. VariablesPlaceHolder if the value is a variable + string. Otherwise, the real type value. 
+ """ + from ...interface.variables import VariablesPlaceHolder, is_variable_string + + is_variables = is_variable_string(self.value) if self.value else False + if is_variables and self.value is not None and isinstance(self.value, str): + return VariablesPlaceHolder(self.name, self.value) + else: + return self._covert_to_real_type(self.type_cls, self.value) def get_typed_default(self) -> Any: """Get the typed default.""" @@ -398,6 +425,7 @@ class Parameter(TypeMetadata, Serializable): label: str, name: str, type: Type, + is_list: bool = False, optional: bool = False, default: Optional[Union[DefaultParameterType, _MISSING_TYPE]] = _MISSING_VALUE, placeholder: Optional[DefaultParameterType] = None, @@ -405,6 +433,7 @@ class Parameter(TypeMetadata, Serializable): options: Optional[Union[BaseDynamicOptions, List[OptionValue]]] = None, resource_type: ResourceType = ResourceType.INSTANCE, alias: Optional[List[str]] = None, + ui: Optional[UIComponent] = None, ): """Build the parameter from the type.""" type_name = type.__qualname__ @@ -419,6 +448,7 @@ class Parameter(TypeMetadata, Serializable): name=name, type_name=type_name, type_cls=type_cls, + is_list=is_list, category=category.value, resource_type=resource_type, optional=optional, @@ -427,6 +457,7 @@ class Parameter(TypeMetadata, Serializable): description=description or label, options=options, alias=alias, + ui=ui, ) @classmethod @@ -456,11 +487,12 @@ class Parameter(TypeMetadata, Serializable): description=data["description"], options=data["options"], value=data["value"], + ui=data.get("ui"), ) def to_dict(self) -> Dict: """Convert current metadata to json dict.""" - dict_value = model_to_dict(self, exclude={"options", "alias"}) + dict_value = model_to_dict(self, exclude={"options", "alias", "ui"}) if not self.options: dict_value["options"] = None elif isinstance(self.options, BaseDynamicOptions): @@ -468,6 +500,36 @@ class Parameter(TypeMetadata, Serializable): dict_value["options"] = [value.to_dict() for value in values] else: dict_value["options"] = [value.to_dict() for value in self.options] + + if self.ui: + dict_value["ui"] = self.ui.to_dict() + return dict_value + + async def refresh( + self, + request: Optional[RefreshOptionRequest] = None, + trigger: Literal["default", "http"] = "default", + system_app: Optional[SystemApp] = None, + ) -> Dict: + """Refresh the options of the parameter. + + Args: + request (RefreshOptionRequest): The request to refresh the options. + trigger (Literal["default", "http"], optional): The trigger type. + Defaults to "default". + system_app (Optional[SystemApp], optional): The system app. + + Returns: + Dict: The response. 
+ """ + dict_value = self.to_dict() + if not self.options: + dict_value["options"] = None + elif isinstance(self.options, BaseDynamicOptions): + values = self.options.refresh(request, trigger, system_app) + dict_value["options"] = [value.to_dict() for value in values] + else: + dict_value["options"] = [value.to_dict() for value in self.options] return dict_value def get_dict_options(self) -> Optional[List[Dict]]: @@ -641,10 +703,10 @@ class BaseMetadata(BaseResource): ], ) - tags: Optional[List[str]] = Field( + tags: Optional[Dict[str, str]] = Field( default=None, description="The tags of the operator", - examples=[["llm", "openai", "gpt3"]], + examples=[{"order": "higher-order"}, {"order": "first-order"}], ) parameters: List[Parameter] = Field( @@ -754,6 +816,58 @@ class BaseMetadata(BaseResource): ] return dict_value + async def refresh( + self, + request: List[RefreshOptionRequest], + trigger: Literal["default", "http"] = "default", + system_app: Optional[SystemApp] = None, + ) -> Dict: + """Refresh the metadata. + + Args: + request (List[RefreshOptionRequest]): The refresh request + trigger (Literal["default", "http"]): The trigger type, how to trigger + the refresh + system_app (Optional[SystemApp]): The system app + """ + executor = DefaultExecutorFactory.get_instance(system_app).create() + + name_to_request = {req.name: req for req in request} + parameter_requests = { + parameter.name: name_to_request.get(parameter.name) + for parameter in self.parameters + } + dict_value = model_to_dict(self, exclude={"parameters"}) + parameters = [] + for parameter in self.parameters: + parameter_dict = parameter.to_dict() + parameter_request = parameter_requests.get(parameter.name) + if not parameter.options: + options = None + elif isinstance(parameter.options, BaseDynamicOptions): + options_obj = parameter.options + if options_obj.support_async(system_app, parameter_request): + values = await options_obj.async_refresh( + parameter_request, trigger, system_app + ) + else: + values = await blocking_func_to_async( + executor, + options_obj.refresh, + parameter_request, + trigger, + system_app, + ) + options = [value.to_dict() for value in values] + else: + options = [value.to_dict() for value in self.options] + parameter_dict["options"] = options + parameters.append(parameter_dict) + + dict_value["parameters"] = parameters + + return dict_value + class ResourceMetadata(BaseMetadata, TypeMetadata): """The metadata of the resource.""" @@ -1033,9 +1147,58 @@ class FlowRegistry: """Get the registry item by the key.""" return self._registry.get(key) - def metadata_list(self): - """Get the metadata list.""" - return [item.metadata.to_dict() for item in self._registry.values()] + def metadata_list( + self, + tags: Optional[Dict[str, str]] = None, + user_name: Optional[str] = None, + sys_code: Optional[str] = None, + ) -> List[Dict]: + """Get the metadata list. + + TODO: Support the user and system code filter. + + Args: + tags (Optional[Dict[str, str]], optional): The tags. Defaults to None. + user_name (Optional[str], optional): The user name. Defaults to None. + sys_code (Optional[str], optional): The system code. Defaults to None. + + Returns: + List[Dict]: The metadata list. 
+ """ + if not tags: + return [item.metadata.to_dict() for item in self._registry.values()] + else: + results = [] + for item in self._registry.values(): + node_tags = item.metadata.tags + is_match = True + if not node_tags or not isinstance(node_tags, dict): + continue + for k, v in tags.items(): + if node_tags.get(k) != v: + is_match = False + break + if is_match: + results.append(item.metadata.to_dict()) + return results + + async def refresh( + self, + key: str, + is_operator: bool, + request: List[RefreshOptionRequest], + trigger: Literal["default", "http"] = "default", + system_app: Optional[SystemApp] = None, + ) -> Dict: + """Refresh the metadata.""" + if is_operator: + return await _get_operator_class(key).metadata.refresh( # type: ignore + request, trigger, system_app + ) + else: + return await _get_resource_class(key).metadata.refresh( + request, trigger, system_app + ) _OPERATOR_REGISTRY: FlowRegistry = FlowRegistry() diff --git a/dbgpt/core/awel/flow/exceptions.py b/dbgpt/core/awel/flow/exceptions.py index 0c3dc667d..68c02f8ac 100644 --- a/dbgpt/core/awel/flow/exceptions.py +++ b/dbgpt/core/awel/flow/exceptions.py @@ -44,3 +44,14 @@ class FlowDAGMetadataException(FlowMetadataException): def __init__(self, message: str, error_type="build_dag_metadata_error"): """Create a new FlowDAGMetadataException.""" super().__init__(message, error_type) + + +class FlowUIComponentException(FlowException): + """The exception for UI parameter failed.""" + + def __init__( + self, message: str, component_name: str, error_type="build_ui_component_error" + ): + """Create a new FlowUIParameterException.""" + new_message = f"{component_name}: {message}" + super().__init__(new_message, error_type) diff --git a/dbgpt/core/awel/flow/flow_factory.py b/dbgpt/core/awel/flow/flow_factory.py index 4db9755b0..87b828971 100644 --- a/dbgpt/core/awel/flow/flow_factory.py +++ b/dbgpt/core/awel/flow/flow_factory.py @@ -4,7 +4,7 @@ import logging import uuid from contextlib import suppress from enum import Enum -from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast +from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union, cast from typing_extensions import Annotated @@ -17,6 +17,7 @@ from dbgpt._private.pydantic import ( model_to_dict, model_validator, ) +from dbgpt.configs import VARIABLES_SCOPE_FLOW_PRIVATE from dbgpt.core.awel.dag.base import DAG, DAGNode from dbgpt.core.awel.dag.dag_manager import DAGMetadata @@ -166,6 +167,143 @@ class FlowData(BaseModel): viewport: FlowPositionData = Field(..., description="Viewport of the flow") +class _VariablesRequestBase(BaseModel): + key: str = Field( + ..., + description="The key of the variable to create", + examples=["dbgpt.model.openai.api_key"], + ) + + label: str = Field( + ..., + description="The label of the variable to create", + examples=["My First OpenAI Key"], + ) + + description: Optional[str] = Field( + None, + description="The description of the variable to create", + examples=["Your OpenAI API key"], + ) + value_type: Literal["str", "int", "float", "bool"] = Field( + "str", + description="The type of the value of the variable to create", + examples=["str", "int", "float", "bool"], + ) + category: Literal["common", "secret"] = Field( + ..., + description="The category of the variable to create", + examples=["common"], + ) + scope: str = Field( + ..., + description="The scope of the variable to create", + examples=["global"], + ) + scope_key: Optional[str] = Field( + None, + description="The scope key of the variable to 
create", + examples=["dbgpt"], + ) + + +class VariablesRequest(_VariablesRequestBase): + """Variable request model. + + For creating a new variable in the DB-GPT. + """ + + name: str = Field( + ..., + description="The name of the variable to create", + examples=["my_first_openai_key"], + ) + value: Any = Field( + ..., description="The value of the variable to create", examples=["1234567890"] + ) + enabled: Optional[bool] = Field( + True, + description="Whether the variable is enabled", + examples=[True], + ) + user_name: Optional[str] = Field(None, description="User name") + sys_code: Optional[str] = Field(None, description="System code") + + +class ParsedFlowVariables(BaseModel): + """Parsed variables for the flow.""" + + key: str = Field( + ..., + description="The key of the variable", + examples=["dbgpt.model.openai.api_key"], + ) + name: Optional[str] = Field( + None, + description="The name of the variable", + examples=["my_first_openai_key"], + ) + scope: str = Field( + ..., + description="The scope of the variable", + examples=["global"], + ) + scope_key: Optional[str] = Field( + None, + description="The scope key of the variable", + examples=["dbgpt"], + ) + sys_code: Optional[str] = Field(None, description="System code") + user_name: Optional[str] = Field(None, description="User name") + + +class FlowVariables(_VariablesRequestBase): + """Variables for the flow.""" + + name: Optional[str] = Field( + None, + description="The name of the variable", + examples=["my_first_openai_key"], + ) + value: Optional[Any] = Field( + None, description="The value of the variable", examples=["1234567890"] + ) + parsed_variables: Optional[ParsedFlowVariables] = Field( + None, description="The parsed variables, parsed from the value" + ) + + @model_validator(mode="before") + @classmethod + def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]: + """Pre fill the metadata.""" + if not isinstance(values, dict): + return values + if "parsed_variables" not in values: + parsed_variables = cls.parse_value_to_variables(values.get("value")) + if parsed_variables: + values["parsed_variables"] = parsed_variables + return values + + @classmethod + def parse_value_to_variables(cls, value: Any) -> Optional[ParsedFlowVariables]: + """Parse the value to variables. 
+ + Args: + value (Any): The value to parse + + Returns: + Optional[ParsedFlowVariables]: The parsed variables, None if the value is + invalid + """ + from ...interface.variables import _is_variable_format, parse_variable + + if not value or not isinstance(value, str) or not _is_variable_format(value): + return None + + variable_dict = parse_variable(value) + return ParsedFlowVariables(**variable_dict) + + class State(str, Enum): """State of a flow panel.""" @@ -356,6 +494,12 @@ class FlowPanel(BaseModel): metadata: Optional[Union[DAGMetadata, Dict[str, Any]]] = Field( default=None, description="The metadata of the flow" ) + variables: Optional[List[FlowVariables]] = Field( + default=None, description="The variables of the flow" + ) + authors: Optional[List[str]] = Field( + default=None, description="The authors of the flow" + ) @model_validator(mode="before") @classmethod @@ -378,6 +522,21 @@ class FlowPanel(BaseModel): """Convert to dict.""" return model_to_dict(self, exclude={"flow_dag"}) + def get_variables_dict(self) -> List[Dict[str, Any]]: + """Get the variables dict.""" + if not self.variables: + return [] + return [v.dict() for v in self.variables] + + @classmethod + def parse_variables( + cls, variables: Optional[List[Dict[str, Any]]] = None + ) -> Optional[List[FlowVariables]]: + """Parse the variables.""" + if not variables: + return None + return [FlowVariables(**v) for v in variables] + class FlowFactory: """Flow factory.""" @@ -598,10 +757,36 @@ class FlowFactory: dag_id: Optional[str] = None, ) -> DAG: """Build the DAG.""" + from ..dag.base import DAGVariables, _DAGVariablesItem + formatted_name = flow_panel.name.replace(" ", "_") if not dag_id: dag_id = f"{self._dag_prefix}_{formatted_name}_{flow_panel.uid}" - with DAG(dag_id) as dag: + + default_dag_variables: Optional[DAGVariables] = None + if flow_panel.variables: + variables = [] + for v in flow_panel.variables: + scope_key = v.scope_key + if v.scope == VARIABLES_SCOPE_FLOW_PRIVATE and not scope_key: + scope_key = dag_id + variables.append( + _DAGVariablesItem( + key=v.key, + name=v.name, # type: ignore + label=v.label, + description=v.description, + value_type=v.value_type, + category=v.category, + scope=v.scope, + scope_key=scope_key, + value=v.value, + user_name=flow_panel.user_name, + sys_code=flow_panel.sys_code, + ) + ) + default_dag_variables = DAGVariables(items=variables) + with DAG(dag_id, default_dag_variables=default_dag_variables) as dag: for key, task in key_to_tasks.items(): if not task._node_id: task.set_node_id(dag._new_node_id()) diff --git a/dbgpt/core/awel/flow/tests/test_flow_variables.py b/dbgpt/core/awel/flow/tests/test_flow_variables.py new file mode 100644 index 000000000..3f1b04154 --- /dev/null +++ b/dbgpt/core/awel/flow/tests/test_flow_variables.py @@ -0,0 +1,294 @@ +import json +from typing import cast + +import pytest + +from dbgpt.configs import VARIABLES_SCOPE_FLOW_PRIVATE +from dbgpt.core.awel import BaseOperator, DAGVar, MapOperator +from dbgpt.core.awel.flow import ( + IOField, + OperatorCategory, + Parameter, + VariablesDynamicOptions, + ViewMetadata, + ui, +) +from dbgpt.core.awel.flow.flow_factory import ( + FlowData, + FlowFactory, + FlowPanel, + FlowVariables, +) + +from ...tests.conftest import variables_provider + + +class MyVariablesOperator(MapOperator[str, str]): + metadata = ViewMetadata( + label="My Test Variables Operator", + name="my_test_variables_operator", + category=OperatorCategory.EXAMPLE, + description="An example flow operator that includes a variables option.", 
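+        # Each parameter below is backed by VariablesDynamicOptions and a
+        # variables-aware UI input, so its value is resolved from the
+        # variables provider when the flow is built.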
+ parameters=[ + Parameter.build_from( + "OpenAI API Key", + "openai_api_key", + type=str, + placeholder="Please select the OpenAI API key", + description="The OpenAI API key to use.", + options=VariablesDynamicOptions(), + ui=ui.UIPasswordInput( + key="dbgpt.model.openai.api_key", + ), + ), + Parameter.build_from( + "Model", + "model", + type=str, + placeholder="Please select the model", + description="The model to use.", + options=VariablesDynamicOptions(), + ui=ui.UIVariablesInput( + key="dbgpt.model.openai.model", + ), + ), + Parameter.build_from( + "DAG Var 1", + "dag_var1", + type=str, + placeholder="Please select the DAG variable 1", + description="The DAG variable 1.", + options=VariablesDynamicOptions(), + ui=ui.UIVariablesInput( + key="dbgpt.core.flow.params", scope=VARIABLES_SCOPE_FLOW_PRIVATE + ), + ), + Parameter.build_from( + "DAG Var 2", + "dag_var2", + type=str, + placeholder="Please select the DAG variable 2", + description="The DAG variable 2.", + options=VariablesDynamicOptions(), + ui=ui.UIVariablesInput( + key="dbgpt.core.flow.params", scope=VARIABLES_SCOPE_FLOW_PRIVATE + ), + ), + ], + inputs=[ + IOField.build_from( + "User Name", + "user_name", + str, + description="The name of the user.", + ), + ], + outputs=[ + IOField.build_from( + "Model info", + "model", + str, + description="The model info.", + ), + ], + ) + + def __init__( + self, openai_api_key: str, model: str, dag_var1: str, dag_var2: str, **kwargs + ): + super().__init__(**kwargs) + self._openai_api_key = openai_api_key + self._model = model + self._dag_var1 = dag_var1 + self._dag_var2 = dag_var2 + + async def map(self, user_name: str) -> str: + dict_dict = { + "openai_api_key": self._openai_api_key, + "model": self._model, + "dag_var1": self._dag_var1, + "dag_var2": self._dag_var2, + } + json_data = json.dumps(dict_dict, ensure_ascii=False) + return "Your name is %s, and your model info is %s." % (user_name, json_data) + + +class EndOperator(MapOperator[str, str]): + metadata = ViewMetadata( + label="End Operator", + name="end_operator", + category=OperatorCategory.EXAMPLE, + description="An example flow operator that ends the flow.", + parameters=[], + inputs=[ + IOField.build_from( + "Input", + "input", + str, + description="The input to the end operator.", + ), + ], + outputs=[ + IOField.build_from( + "Output", + "output", + str, + description="The output of the end operator.", + ), + ], + ) + + async def map(self, input: str) -> str: + return f"End operator received input: {input}" + + +@pytest.fixture +def json_flow(): + operators = [MyVariablesOperator, EndOperator] + metadata_list = [operator.metadata.to_dict() for operator in operators] + node_names = {} + name_to_parameters_dict = { + "my_test_variables_operator": { + "openai_api_key": "${dbgpt.model.openai.api_key:my_key@global}", + "model": "${dbgpt.model.openai.model:default_model@global}", + "dag_var1": "${dbgpt.core.flow.params:name1@%s}" + % VARIABLES_SCOPE_FLOW_PRIVATE, + "dag_var2": "${dbgpt.core.flow.params:name2@%s}" + % VARIABLES_SCOPE_FLOW_PRIVATE, + } + } + name_to_metadata_dict = {metadata["name"]: metadata for metadata in metadata_list} + ui_nodes = [] + for metadata in metadata_list: + type_name = metadata["type_name"] + name = metadata["name"] + id = metadata["id"] + if type_name in node_names: + raise ValueError(f"Duplicate node type name: {type_name}") + # Replace id to flow data id. 
+ metadata["id"] = f"{id}_0" + parameters = metadata["parameters"] + parameters_dict = name_to_parameters_dict.get(name, {}) + for parameter in parameters: + parameter_name = parameter["name"] + if parameter_name in parameters_dict: + parameter["value"] = parameters_dict[parameter_name] + ui_nodes.append( + { + "width": 288, + "height": 352, + "id": metadata["id"], + "position": { + "x": -149.98120112708142, + "y": 666.9468497341901, + "zoom": 0.0, + }, + "type": "customNode", + "position_absolute": { + "x": -149.98120112708142, + "y": 666.9468497341901, + "zoom": 0.0, + }, + "data": metadata, + } + ) + + ui_edges = [] + source_id = name_to_metadata_dict["my_test_variables_operator"]["id"] + target_id = name_to_metadata_dict["end_operator"]["id"] + ui_edges.append( + { + "source": source_id, + "target": target_id, + "source_order": 0, + "target_order": 0, + "id": f"{source_id}|{target_id}", + "source_handle": f"{source_id}|outputs|0", + "target_handle": f"{target_id}|inputs|0", + "type": "buttonedge", + } + ) + return { + "nodes": ui_nodes, + "edges": ui_edges, + "viewport": { + "x": 509.2191773722104, + "y": -66.11286175905718, + "zoom": 1.252741002590748, + }, + } + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "variables_provider", + [ + ( + { + "vars": { + "openai_api_key": { + "key": "${dbgpt.model.openai.api_key:my_key@global}", + "value": "my_openai_api_key", + "value_type": "str", + "category": "secret", + }, + "model": { + "key": "${dbgpt.model.openai.model:default_model@global}", + "value": "GPT-4o", + "value_type": "str", + }, + } + } + ), + ], + indirect=["variables_provider"], +) +async def test_build_flow(json_flow, variables_provider): + DAGVar.set_variables_provider(variables_provider) + flow_data = FlowData(**json_flow) + variables = [ + FlowVariables( + key="dbgpt.core.flow.params", + name="name1", + label="Name 1", + value="value1", + value_type="str", + category="common", + scope=VARIABLES_SCOPE_FLOW_PRIVATE, + # scope_key="my_test_flow", + ), + FlowVariables( + key="dbgpt.core.flow.params", + name="name2", + label="Name 2", + value="value2", + value_type="str", + category="common", + scope=VARIABLES_SCOPE_FLOW_PRIVATE, + # scope_key="my_test_flow", + ), + ] + flow_panel = FlowPanel( + label="My Test Flow", + name="my_test_flow", + flow_data=flow_data, + state="deployed", + variables=variables, + ) + factory = FlowFactory() + dag = factory.build(flow_panel) + + leaf_node: BaseOperator = cast(BaseOperator, dag.leaf_nodes[0]) + result = await leaf_node.call("Alice") + expected_dict = { + "openai_api_key": "my_openai_api_key", + "model": "GPT-4o", + "dag_var1": "value1", + "dag_var2": "value2", + } + expected_dict_str = json.dumps(expected_dict, ensure_ascii=False) + assert ( + result + == f"End operator received input: Your name is Alice, and your model info is " + f"{expected_dict_str}." 
+    )
diff --git a/dbgpt/core/awel/flow/ui.py b/dbgpt/core/awel/flow/ui.py
new file mode 100644
index 000000000..928755a20
--- /dev/null
+++ b/dbgpt/core/awel/flow/ui.py
@@ -0,0 +1,456 @@
+"""UI components for AWEL flow."""
+
+from typing import Any, Dict, List, Literal, Optional, Union
+
+from dbgpt._private.pydantic import BaseModel, Field, model_to_dict
+from dbgpt.core.interface.serialization import Serializable
+
+from .exceptions import FlowUIComponentException
+
+_UI_TYPE = Literal[
+    "select",
+    "cascader",
+    "checkbox",
+    "radio",
+    "date_picker",
+    "input",
+    "text_area",
+    "auto_complete",
+    "slider",
+    "time_picker",
+    "tree_select",
+    "upload",
+    "variables",
+    "password",
+    "code_editor",
+]
+
+
+class RefreshableMixin(BaseModel):
+    """Refreshable mixin."""
+
+    refresh: Optional[bool] = Field(
+        False,
+        description="Whether to enable the refresh",
+    )
+    refresh_depends: Optional[List[str]] = Field(
+        None,
+        description="The dependencies of the refresh",
+    )
+
+
+class StatusMixin(BaseModel):
+    """Status mixin."""
+
+    status: Optional[Literal["error", "warning"]] = Field(
+        None,
+        description="Status of the input",
+    )
+
+
+class PanelEditorMixin(BaseModel):
+    """Edit the content in the panel."""
+
+    class Editor(BaseModel):
+        """Editor configuration."""
+
+        width: Optional[int] = Field(
+            None,
+            description="The width of the panel",
+        )
+        height: Optional[int] = Field(
+            None,
+            description="The height of the panel",
+        )
+
+    editor: Optional[Editor] = Field(
+        default_factory=lambda: PanelEditorMixin.Editor(width=800, height=400),
+        description="The editor configuration",
+    )
+
+
+class UIComponent(RefreshableMixin, Serializable, BaseModel):
+    """UI component."""
+
+    class UIAttribute(BaseModel):
+        """Base UI attribute."""
+
+        disabled: bool = Field(
+            False,
+            description="Whether the component is disabled",
+        )
+
+    ui_type: _UI_TYPE = Field(..., description="UI component type")
+
+    attr: Optional[UIAttribute] = Field(
+        None,
+        description="The attributes of the component",
+    )
+
+    def check_parameter(self, parameter_dict: Dict[str, Any]):
+        """Check parameter.
+
+        Raises:
+            FlowUIComponentException: If the parameter is invalid.
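+
+        Example (an illustrative sketch; subclasses such as UISelect require
+        non-empty options):
+            .. code-block:: python
+
+                select = UISelect()
+                # Passes: options are present
+                select.check_parameter(
+                    {"options": [{"label": "A", "name": "a", "value": "a"}]}
+                )
+                # Raises FlowUIComponentException: options are missing
+                select.check_parameter({})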
+ """ + + def _check_options(self, options: Dict[str, Any]): + """Check options.""" + if not options: + raise FlowUIComponentException("options is required", self.ui_type) + + def to_dict(self) -> Dict: + """Convert current metadata to json dict.""" + return model_to_dict(self) + + +class UISelect(UIComponent): + """Select component.""" + + class UIAttribute(StatusMixin, UIComponent.UIAttribute): + """Select attribute.""" + + show_search: bool = Field( + False, + description="Whether to show search input", + ) + mode: Optional[Literal["tags"]] = Field( + None, + description="The mode of the select", + ) + placement: Optional[ + Literal["topLeft", "topRight", "bottomLeft", "bottomRight"] + ] = Field( + None, + description="The position of the picker panel, None means bottomLeft", + ) + + ui_type: Literal["select"] = Field("select", frozen=True) + attr: Optional[UIAttribute] = Field( + None, + description="The attributes of the component", + ) + + def check_parameter(self, parameter_dict: Dict[str, Any]): + """Check parameter.""" + self._check_options(parameter_dict.get("options", {})) + + +class UICascader(UIComponent): + """Cascader component.""" + + class UIAttribute(StatusMixin, UIComponent.UIAttribute): + """Cascader attribute.""" + + show_search: bool = Field( + False, + description="Whether to show search input", + ) + + ui_type: Literal["cascader"] = Field("cascader", frozen=True) + + attr: Optional[UIAttribute] = Field( + None, + description="The attributes of the component", + ) + + def check_parameter(self, parameter_dict: Dict[str, Any]): + """Check parameter.""" + options = parameter_dict.get("options") + if not options: + raise FlowUIComponentException("options is required", self.ui_type) + first_level = options[0] + if "children" not in first_level: + raise FlowUIComponentException( + "children is required in options", self.ui_type + ) + + +class UICheckbox(UIComponent): + """Checkbox component.""" + + ui_type: Literal["checkbox"] = Field("checkbox", frozen=True) + + def check_parameter(self, parameter_dict: Dict[str, Any]): + """Check parameter.""" + self._check_options(parameter_dict.get("options", {})) + + +class UIRadio(UICheckbox): + """Radio component.""" + + ui_type: Literal["radio"] = Field("radio", frozen=True) # type: ignore + + +class UIDatePicker(UIComponent): + """Date picker component.""" + + class UIAttribute(StatusMixin, UIComponent.UIAttribute): + """Date picker attribute.""" + + placement: Optional[ + Literal["topLeft", "topRight", "bottomLeft", "bottomRight"] + ] = Field( + None, + description="The position of the picker panel, None means bottomLeft", + ) + + ui_type: Literal["date_picker"] = Field("date_picker", frozen=True) + + attr: Optional[UIAttribute] = Field( + None, + description="The attributes of the component", + ) + + +class UIInput(UIComponent): + """Input component.""" + + class UIAttribute(StatusMixin, UIComponent.UIAttribute): + """Input attribute.""" + + prefix: Optional[str] = Field( + None, + description="The prefix, icon or text", + examples=["$", "icon:UserOutlined"], + ) + suffix: Optional[str] = Field( + None, + description="The suffix, icon or text", + examples=["$", "icon:SearchOutlined"], + ) + show_count: Optional[bool] = Field( + None, + description="Whether to show count", + ) + max_length: Optional[int] = Field( + None, + description="The maximum length of the input", + ) + + ui_type: Literal["input"] = Field("input", frozen=True) + + attr: Optional[UIAttribute] = Field( + None, + description="The attributes of the component", + 
) + + +class UITextArea(PanelEditorMixin, UIInput): + """Text area component.""" + + class UIAttribute(UIInput.UIAttribute): + """Text area attribute.""" + + class AutoSize(BaseModel): + """Auto size configuration.""" + + min_rows: Optional[int] = Field( + None, + description="The minimum number of rows", + ) + max_rows: Optional[int] = Field( + None, + description="The maximum number of rows", + ) + + auto_size: Optional[Union[bool, AutoSize]] = Field( + None, + description="Whether the height of the textarea automatically adjusts " + "based on the content", + ) + + ui_type: Literal["text_area"] = Field("text_area", frozen=True) # type: ignore + attr: Optional[UIAttribute] = Field( + None, + description="The attributes of the component", + ) + + +class UIAutoComplete(UIInput): + """Auto complete component.""" + + ui_type: Literal["auto_complete"] = Field( # type: ignore + "auto_complete", frozen=True + ) + + +class UISlider(UIComponent): + """Slider component.""" + + class UIAttribute(UIComponent.UIAttribute): + """Slider attribute.""" + + min: Optional[int | float] = Field( + None, + description="The minimum value", + ) + max: Optional[int | float] = Field( + None, + description="The maximum value", + ) + step: Optional[int | float] = Field( + None, + description="The step of the slider", + ) + + ui_type: Literal["slider"] = Field("slider", frozen=True) + + show_input: bool = Field( + False, description="Whether to display the value in a input component" + ) + + attr: Optional[UIAttribute] = Field( + None, + description="The attributes of the component", + ) + + +class UITimePicker(UIComponent): + """Time picker component.""" + + class UIAttribute(StatusMixin, UIComponent.UIAttribute): + """Time picker attribute.""" + + format: Optional[str] = Field( + None, + description="The format of the time", + examples=["HH:mm:ss", "HH:mm"], + ) + hour_step: Optional[int] = Field( + None, + description="The step of the hour input", + ) + minute_step: Optional[int] = Field( + None, + description="The step of the minute input", + ) + second_step: Optional[int] = Field( + None, + description="The step of the second input", + ) + + ui_type: Literal["time_picker"] = Field("time_picker", frozen=True) + + attr: Optional[UIAttribute] = Field( + None, + description="The attributes of the component", + ) + + +class UITreeSelect(UICascader): + """Tree select component.""" + + ui_type: Literal["tree_select"] = Field("tree_select", frozen=True) # type: ignore + + def check_parameter(self, parameter_dict: Dict[str, Any]): + """Check parameter.""" + options = parameter_dict.get("options") + if not options: + raise FlowUIComponentException("options is required", self.ui_type) + first_level = options[0] + if "children" not in first_level: + raise FlowUIComponentException( + "children is required in options", self.ui_type + ) + + +class UIUpload(UIComponent): + """Upload component.""" + + class UIAttribute(UIComponent.UIAttribute): + """Upload attribute.""" + + max_count: Optional[int] = Field( + None, + description="The maximum number of files that can be uploaded", + ) + + ui_type: Literal["upload"] = Field("upload", frozen=True) + attr: Optional[UIAttribute] = Field( + None, + description="The attributes of the component", + ) + max_file_size: Optional[int] = Field( + None, + description="The maximum size of the file, in bytes", + ) + + file_types: Optional[List[str]] = Field( + None, + description="The file types that can be accepted", + examples=[[".png", ".jpg"]], + ) + up_event: 
Optional[Literal["after_select", "button_click"]] = Field(
+        None,
+        description="The event that triggers the upload",
+    )
+    drag: bool = Field(
+        False,
+        description="Whether to support drag and drop upload",
+    )
+    action: Optional[str] = Field(
+        "/api/v2/serve/file/files/dbgpt",
+        description="The URL for the file upload(default bucket is 'dbgpt')",
+    )
+
+
+class UIVariablesInput(UIInput):
+    """Variables input component."""
+
+    ui_type: Literal["variables"] = Field("variables", frozen=True)  # type: ignore
+    key: str = Field(..., description="The key of the variable")
+    key_type: Literal["common", "secret"] = Field(
+        "common",
+        description="The type of the key",
+    )
+    scope: str = Field("global", description="The scope of the variables")
+    scope_key: Optional[str] = Field(
+        None,
+        description="The key of the scope",
+    )
+    refresh: Optional[bool] = Field(
+        True,
+        description="Whether to enable the refresh",
+    )
+
+    def check_parameter(self, parameter_dict: Dict[str, Any]):
+        """Check parameter."""
+        self._check_options(parameter_dict.get("options", {}))
+
+
+class UIPasswordInput(UIVariablesInput):
+    """Password input component."""
+
+    ui_type: Literal["password"] = Field("password", frozen=True)  # type: ignore
+
+    key_type: Literal["secret"] = Field(
+        "secret",
+        description="The type of the key",
+    )
+
+    def check_parameter(self, parameter_dict: Dict[str, Any]):
+        """Check parameter."""
+        self._check_options(parameter_dict.get("options", {}))
+
+
+class UICodeEditor(UITextArea):
+    """Code editor component."""
+
+    ui_type: Literal["code_editor"] = Field("code_editor", frozen=True)  # type: ignore
+
+    language: Optional[str] = Field(
+        "python",
+        description="The language of the code",
+    )
+
+
+class DefaultUITextArea(UITextArea):
+    """Default text area component."""
+
+    attr: Optional[UITextArea.UIAttribute] = Field(
+        default_factory=lambda: UITextArea.UIAttribute(
+            auto_size=UITextArea.UIAttribute.AutoSize(min_rows=2, max_rows=40)
+        ),
+        description="The attributes of the component",
+    )
diff --git a/dbgpt/core/awel/operators/base.py b/dbgpt/core/awel/operators/base.py
index aafa11f90..da82d2856 100644
--- a/dbgpt/core/awel/operators/base.py
+++ b/dbgpt/core/awel/operators/base.py
@@ -2,10 +2,12 @@

 import asyncio
 import functools
+import logging
 from abc import ABC, ABCMeta, abstractmethod
 from contextvars import ContextVar
 from types import FunctionType
 from typing import (
+    TYPE_CHECKING,
     Any,
     AsyncIterator,
     Dict,
@@ -18,6 +20,7 @@ from typing import (
 )

 from dbgpt.component import ComponentType, SystemApp
+from dbgpt.configs import VARIABLES_SCOPE_FLOW_PRIVATE
 from dbgpt.util.executor_utils import (
     AsyncToSyncIterator,
     BlockingFunction,
@@ -26,9 +29,14 @@ from dbgpt.util.executor_utils import (
 )
 from dbgpt.util.tracer import root_tracer

-from ..dag.base import DAG, DAGContext, DAGNode, DAGVar
+from ..dag.base import DAG, DAGContext, DAGNode, DAGVar, DAGVariables
 from ..task.base import EMPTY_DATA, OUT, T, TaskOutput, is_empty_data

+if TYPE_CHECKING:
+    from ...interface.variables import VariablesProvider
+
+logger = logging.getLogger(__name__)
+
 F = TypeVar("F", bound=FunctionType)

 CALL_DATA = Union[Dict[str, Any], Any]
@@ -51,6 +59,7 @@ class WorkflowRunner(ABC, Generic[T]):
         call_data: Optional[CALL_DATA] = None,
         streaming_call: bool = False,
         exist_dag_ctx: Optional[DAGContext] = None,
+        dag_variables: Optional[DAGVariables] = None,
     ) -> DAGContext:
         """Execute the workflow starting from a given operator.
@@ -60,6 +69,7 @@ class WorkflowRunner(ABC, Generic[T]): streaming_call (bool): Whether the call is a streaming call. exist_dag_ctx (DAGContext): The context of the DAG when this node is run, Defaults to None. + dag_variables (DAGVariables): The DAG variables. Returns: DAGContext: The context after executing the workflow, containing the final state and data. @@ -92,6 +102,9 @@ class BaseOperatorMeta(ABCMeta): kwargs.get("system_app") or DAGVar.get_current_system_app() ) executor = kwargs.get("executor") or DAGVar.get_executor() + variables_provider = ( + kwargs.get("variables_provider") or DAGVar.get_variables_provider() + ) if not executor: if system_app: executor = system_app.get_component( @@ -102,14 +115,24 @@ class BaseOperatorMeta(ABCMeta): else: executor = DefaultExecutorFactory().create() DAGVar.set_executor(executor) + if not variables_provider: + from ...interface.variables import VariablesProvider + + if system_app: + variables_provider = system_app.get_component( + ComponentType.VARIABLES_PROVIDER, + VariablesProvider, + default_component=None, + ) + else: + from ...interface.variables import StorageVariablesProvider + + variables_provider = StorageVariablesProvider() + DAGVar.set_variables_provider(variables_provider) if not task_id and dag: task_id = dag._new_node_id() runner: Optional[WorkflowRunner] = kwargs.get("runner") or default_runner - # print(f"self: {self}, kwargs dag: {kwargs.get('dag')}, kwargs: {kwargs}") - # for arg in sig_cache.parameters: - # if arg not in kwargs: - # kwargs[arg] = default_args[arg] if not kwargs.get("dag"): kwargs["dag"] = dag if not kwargs.get("task_id"): @@ -120,6 +143,8 @@ class BaseOperatorMeta(ABCMeta): kwargs["system_app"] = system_app if not kwargs.get("executor"): kwargs["executor"] = executor + if not kwargs.get("variables_provider"): + kwargs["variables_provider"] = variables_provider real_obj = func(self, *args, **kwargs) return real_obj @@ -150,6 +175,7 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta): dag: Optional[DAG] = None, runner: Optional[WorkflowRunner] = None, can_skip_in_branch: bool = True, + variables_provider: Optional["VariablesProvider"] = None, **kwargs, ) -> None: """Create a BaseOperator with an optional workflow runner. @@ -171,6 +197,7 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta): self._runner: WorkflowRunner = runner self._dag_ctx: Optional[DAGContext] = None self._can_skip_in_branch = can_skip_in_branch + self._variables_provider = variables_provider @property def current_dag_context(self) -> DAGContext: @@ -199,6 +226,8 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta): if not task_log_id: raise ValueError(f"The task log ID can't be empty, current node {self}") CURRENT_DAG_CONTEXT.set(dag_ctx) + # Resolve variables + await self._resolve_variables(dag_ctx) return await self._do_run(dag_ctx) @abstractmethod @@ -217,6 +246,7 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta): self, call_data: Optional[CALL_DATA] = EMPTY_DATA, dag_ctx: Optional[DAGContext] = None, + dag_variables: Optional[DAGVariables] = None, ) -> OUT: """Execute the node and return the output. @@ -226,6 +256,7 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta): call_data (CALL_DATA): The data pass to root operator node. dag_ctx (DAGContext): The context of the DAG when this node is run, Defaults to None. + dag_variables (DAGVariables): The DAG variables passed to current DAG. 
Returns: OUT: The output of the node after execution. """ @@ -233,13 +264,15 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta): call_data = {"data": call_data} with root_tracer.start_span("dbgpt.awel.operator.call"): out_ctx = await self._runner.execute_workflow( - self, call_data, exist_dag_ctx=dag_ctx + self, call_data, exist_dag_ctx=dag_ctx, dag_variables=dag_variables ) return out_ctx.current_task_context.task_output.output def _blocking_call( self, call_data: Optional[CALL_DATA] = EMPTY_DATA, + dag_ctx: Optional[DAGContext] = None, + dag_variables: Optional[DAGVariables] = None, loop: Optional[asyncio.BaseEventLoop] = None, ) -> OUT: """Execute the node and return the output. @@ -249,7 +282,10 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta): Args: call_data (CALL_DATA): The data pass to root operator node. - + dag_ctx (DAGContext): The context of the DAG when this node is run, + Defaults to None. + dag_variables (DAGVariables): The DAG variables passed to current DAG. + loop (asyncio.BaseEventLoop): The event loop to run the operator. Returns: OUT: The output of the node after execution. """ @@ -258,12 +294,13 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta): if not loop: loop = get_or_create_event_loop() loop = cast(asyncio.BaseEventLoop, loop) - return loop.run_until_complete(self.call(call_data)) + return loop.run_until_complete(self.call(call_data, dag_ctx, dag_variables)) async def call_stream( self, call_data: Optional[CALL_DATA] = EMPTY_DATA, dag_ctx: Optional[DAGContext] = None, + dag_variables: Optional[DAGVariables] = None, ) -> AsyncIterator[OUT]: """Execute the node and return the output as a stream. @@ -273,7 +310,7 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta): call_data (CALL_DATA): The data pass to root operator node. dag_ctx (DAGContext): The context of the DAG when this node is run, Defaults to None. - + dag_variables (DAGVariables): The DAG variables passed to current DAG. Returns: AsyncIterator[OUT]: An asynchronous iterator over the output stream. """ @@ -281,7 +318,11 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta): call_data = {"data": call_data} with root_tracer.start_span("dbgpt.awel.operator.call_stream"): out_ctx = await self._runner.execute_workflow( - self, call_data, streaming_call=True, exist_dag_ctx=dag_ctx + self, + call_data, + streaming_call=True, + exist_dag_ctx=dag_ctx, + dag_variables=dag_variables, ) task_output = out_ctx.current_task_context.task_output @@ -302,6 +343,8 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta): def _blocking_call_stream( self, call_data: Optional[CALL_DATA] = EMPTY_DATA, + dag_ctx: Optional[DAGContext] = None, + dag_variables: Optional[DAGVariables] = None, loop: Optional[asyncio.BaseEventLoop] = None, ) -> Iterator[OUT]: """Execute the node and return the output as a stream. @@ -311,7 +354,10 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta): Args: call_data (CALL_DATA): The data pass to root operator node. - + dag_ctx (DAGContext): The context of the DAG when this node is run, + Defaults to None. + dag_variables (DAGVariables): The DAG variables passed to current DAG. + loop (asyncio.BaseEventLoop): The event loop to run the operator. Returns: Iterator[OUT]: An iterator over the output stream. 
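+
+        Example (an illustrative sketch; assumes ``task`` is a streaming
+        operator called outside an async context):
+            .. code-block:: python
+
+                for chunk in task._blocking_call_stream("hello"):
+                    print(chunk)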
""" @@ -319,7 +365,9 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta): if not loop: loop = get_or_create_event_loop() - return AsyncToSyncIterator(self.call_stream(call_data), loop) + return AsyncToSyncIterator( + self.call_stream(call_data, dag_ctx, dag_variables), loop + ) async def blocking_func_to_async( self, func: BlockingFunction, *args, **kwargs @@ -347,6 +395,79 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta): """Check if the operator can be skipped in the branch.""" return self._can_skip_in_branch + async def _resolve_variables(self, dag_ctx: DAGContext): + """Resolve variables in the operator. + + Some attributes of the operator may be VariablesPlaceHolder, which need to be + resolved before the operator is executed. + + Args: + dag_ctx (DAGContext): The context of the DAG when this node is run. + """ + from ...interface.variables import VariablesIdentifier, VariablesPlaceHolder + + if not self._variables_provider: + return + + if dag_ctx._dag_variables: + # Resolve variables in DAG context + resolve_tasks = [] + resolve_items = [] + for item in dag_ctx._dag_variables.items: + # TODO: Resolve variables just once? + if isinstance(item.value, VariablesPlaceHolder): + resolve_tasks.append( + self.blocking_func_to_async( + item.value.parse, self._variables_provider + ) + ) + resolve_items.append(item) + resolved_values = await asyncio.gather(*resolve_tasks) + for item, rv in zip(resolve_items, resolved_values): + item.value = rv + dag_provider: Optional["VariablesProvider"] = None + if dag_ctx._dag_variables: + dag_provider = dag_ctx._dag_variables.to_provider() + + # TODO: Resolve variables parallel + for attr, value in self.__dict__.items(): + # Handle all attributes that are VariablesPlaceHolder + if isinstance(value, VariablesPlaceHolder): + resolved_value: Any = None + default_identifier_map = None + id_key = VariablesIdentifier.from_str_identifier(value.full_key) + if ( + id_key.scope == VARIABLES_SCOPE_FLOW_PRIVATE + and id_key.scope_key is None + and self.dag + ): + default_identifier_map = {"scope_key": self.dag.dag_id} + + if dag_provider: + # First try to resolve the variable with the DAG variables + resolved_value = await self.blocking_func_to_async( + value.parse, + dag_provider, + ignore_not_found_error=True, + default_identifier_map=default_identifier_map, + ) + if resolved_value is None: + resolved_value = await self.blocking_func_to_async( + value.parse, + self._variables_provider, + default_identifier_map=default_identifier_map, + ) + logger.debug( + f"Resolve variable {attr} with value {resolved_value} for " + f"{self} from system variables" + ) + else: + logger.debug( + f"Resolve variable {attr} with value {resolved_value} for " + f"{self} from DAG variables" + ) + setattr(self, attr, resolved_value) + def initialize_runner(runner: WorkflowRunner): """Initialize the default runner.""" diff --git a/dbgpt/core/awel/operators/common_operator.py b/dbgpt/core/awel/operators/common_operator.py index fc2dc098b..f8bc25370 100644 --- a/dbgpt/core/awel/operators/common_operator.py +++ b/dbgpt/core/awel/operators/common_operator.py @@ -334,7 +334,8 @@ class InputOperator(BaseOperator, Generic[OUT]): async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]: curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context task_output = await self._input_source.read(curr_task_ctx) - curr_task_ctx.set_task_output(task_output) + new_task_output: TaskOutput[OUT] = await task_output.map(self.map) + 
curr_task_ctx.set_task_output(new_task_output) return task_output @classmethod @@ -342,6 +343,10 @@ class InputOperator(BaseOperator, Generic[OUT]): """Create a dummy InputOperator with a given input value.""" return cls(input_source=InputSource.from_data(dummy_data), **kwargs) + async def map(self, input_data: OUT) -> OUT: + """Map the input data to a new value.""" + return input_data + class TriggerOperator(InputOperator[OUT], Generic[OUT]): """Operator node that triggers the DAG to run.""" diff --git a/dbgpt/core/awel/runner/local_runner.py b/dbgpt/core/awel/runner/local_runner.py index f968be9af..9fdf33b83 100644 --- a/dbgpt/core/awel/runner/local_runner.py +++ b/dbgpt/core/awel/runner/local_runner.py @@ -11,7 +11,7 @@ from typing import Any, Dict, List, Optional, Set, cast from dbgpt.component import SystemApp from dbgpt.util.tracer import root_tracer -from ..dag.base import DAGContext, DAGVar +from ..dag.base import DAGContext, DAGVar, DAGVariables from ..operators.base import CALL_DATA, BaseOperator, WorkflowRunner from ..operators.common_operator import BranchOperator from ..task.base import SKIP_DATA, TaskContext, TaskState @@ -46,6 +46,7 @@ class DefaultWorkflowRunner(WorkflowRunner): call_data: Optional[CALL_DATA] = None, streaming_call: bool = False, exist_dag_ctx: Optional[DAGContext] = None, + dag_variables: Optional[DAGVariables] = None, ) -> DAGContext: """Execute the workflow. @@ -57,6 +58,7 @@ class DefaultWorkflowRunner(WorkflowRunner): Defaults to False. exist_dag_ctx (Optional[DAGContext], optional): The exist DAG context. Defaults to None. + dag_variables (Optional[DAGVariables], optional): The DAG variables. """ # Save node output # dag = node.dag @@ -71,12 +73,19 @@ class DefaultWorkflowRunner(WorkflowRunner): node_outputs = exist_dag_ctx._node_to_outputs share_data = exist_dag_ctx._share_data event_loop_task_id = exist_dag_ctx._event_loop_task_id + if dag_variables and exist_dag_ctx._dag_variables: + # Merge dag variables, prefer the `dag_variables` in the parameter + dag_variables = dag_variables.merge(exist_dag_ctx._dag_variables) + if node.dag and not dag_variables and node.dag._default_dag_variables: + # Use default dag variables if not set + dag_variables = node.dag._default_dag_variables dag_ctx = DAGContext( event_loop_task_id=event_loop_task_id, node_to_outputs=node_outputs, share_data=share_data, streaming_call=streaming_call, node_name_to_ids=job_manager._node_name_to_ids, + dag_variables=dag_variables, ) # if node.dag: # self._running_dag_ctx[node.dag.dag_id] = dag_ctx diff --git a/dbgpt/core/awel/tests/conftest.py b/dbgpt/core/awel/tests/conftest.py index d68ddcfc8..607783028 100644 --- a/dbgpt/core/awel/tests/conftest.py +++ b/dbgpt/core/awel/tests/conftest.py @@ -1,17 +1,15 @@ -from contextlib import asynccontextmanager, contextmanager +from contextlib import asynccontextmanager from typing import AsyncIterator, List import pytest import pytest_asyncio -from .. import ( - DAGContext, - DefaultWorkflowRunner, - InputOperator, - SimpleInputSource, - TaskState, - WorkflowRunner, +from ...interface.variables import ( + StorageVariables, + StorageVariablesProvider, + VariablesIdentifier, ) +from .. 
import DefaultWorkflowRunner, InputOperator, SimpleInputSource from ..task.task_impl import _is_async_iterator @@ -102,3 +100,32 @@ async def stream_input_nodes(request): param["is_stream"] = True async with _create_input_node(**param) as input_nodes: yield input_nodes + + +@asynccontextmanager +async def _create_variables(**kwargs): + vp = StorageVariablesProvider() + vars = kwargs.get("vars") + if vars and isinstance(vars, dict): + for param_key, param_var in vars.items(): + key = param_var.get("key") + value = param_var.get("value") + value_type = param_var.get("value_type") + category = param_var.get("category", "common") + id = VariablesIdentifier.from_str_identifier(key) + vp.save( + StorageVariables.from_identifier( + id, value, value_type, label="", category=category + ) + ) + else: + raise ValueError("vars is required.") + + yield vp + + +@pytest_asyncio.fixture +async def variables_provider(request): + param = getattr(request, "param", {}) + async with _create_variables(**param) as vp: + yield vp diff --git a/dbgpt/core/awel/tests/test_dag_variables.py b/dbgpt/core/awel/tests/test_dag_variables.py new file mode 100644 index 000000000..8bdb29143 --- /dev/null +++ b/dbgpt/core/awel/tests/test_dag_variables.py @@ -0,0 +1,111 @@ +from contextlib import asynccontextmanager + +import pytest +import pytest_asyncio + +from ...interface.variables import ( + StorageVariables, + StorageVariablesProvider, + VariablesIdentifier, + VariablesPlaceHolder, +) +from .. import DAG, DAGVar, InputOperator, MapOperator, SimpleInputSource + + +class VariablesOperator(MapOperator[str, str]): + def __init__(self, int_var: int, str_var: str, secret: str, **kwargs): + super().__init__(**kwargs) + self._int_var = int_var + self._str_var = str_var + self._secret = secret + + async def map(self, x: str) -> str: + return ( + f"x: {x}, int_var: {self._int_var}, str_var: {self._str_var}, " + f"secret: {self._secret}" + ) + + +@pytest.fixture +def default_dag(): + with DAG("test_dag") as dag: + input_node = InputOperator(input_source=SimpleInputSource.from_callable()) + map_node = MapOperator(lambda x: x * 2) + input_node >> map_node + return dag + + +@asynccontextmanager +async def _create_variables(**kwargs): + variables_provider = StorageVariablesProvider() + DAGVar.set_variables_provider(variables_provider) + + vars = kwargs.get("vars") + variables = {} + if vars and isinstance(vars, dict): + for param_key, param_var in vars.items(): + key = param_var.get("key") + value = param_var.get("value") + value_type = param_var.get("value_type") + category = param_var.get("category", "common") + id = VariablesIdentifier.from_str_identifier(key) + variables_provider.save( + StorageVariables.from_identifier( + id, value, value_type, label="", category=category + ) + ) + variables[param_key] = VariablesPlaceHolder(param_key, key) + else: + raise ValueError("vars is required.") + + with DAG("simple_dag") as dag: + map_node = VariablesOperator(**variables) + yield map_node + + +@pytest_asyncio.fixture +async def variables_node(request): + param = getattr(request, "param", {}) + async with _create_variables(**param) as node: + yield node + + +@pytest.mark.asyncio +async def test_default_dag(default_dag: DAG): + leaf_node = default_dag.leaf_nodes[0] + res = await leaf_node.call(2) + assert res == 4 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "variables_node", + [ + ( + { + "vars": { + "int_var": { + "key": "${int_key:my_int_var@global}", + "value": 0, + "value_type": "int", + }, + "str_var": { + "key": 
"${str_key:my_str_var@global}", + "value": "1", + "value_type": "str", + }, + "secret": { + "key": "${secret_key:my_secret_var@global}", + "value": "2131sdsdf", + "value_type": "str", + "category": "secret", + }, + } + } + ), + ], + indirect=["variables_node"], +) +async def test_input_nodes(variables_node: VariablesOperator): + res = await variables_node.call("test") + assert res == "x: test, int_var: 0, str_var: 1, secret: 2131sdsdf" diff --git a/dbgpt/core/awel/trigger/http_trigger.py b/dbgpt/core/awel/trigger/http_trigger.py index 22e025c13..8f0298297 100644 --- a/dbgpt/core/awel/trigger/http_trigger.py +++ b/dbgpt/core/awel/trigger/http_trigger.py @@ -87,7 +87,9 @@ class HttpTriggerMetadata(TriggerMetadata): path: str = Field(..., description="The path of the trigger") methods: List[str] = Field(..., description="The methods of the trigger") - + trigger_mode: str = Field( + default="command", description="The mode of the trigger, command or chat" + ) trigger_type: Optional[str] = Field( default="http", description="The type of the trigger" ) @@ -477,7 +479,9 @@ class HttpTrigger(Trigger): )(dynamic_route_function) logger.info(f"Mount http trigger success, path: {path}") - return HttpTriggerMetadata(path=path, methods=self._methods) + return HttpTriggerMetadata( + path=path, methods=self._methods, trigger_mode=self._trigger_mode() + ) def mount_to_app( self, app: "FastAPI", global_prefix: Optional[str] = None @@ -512,7 +516,9 @@ class HttpTrigger(Trigger): app.openapi_schema = None app.middleware_stack = None logger.info(f"Mount http trigger success, path: {path}") - return HttpTriggerMetadata(path=path, methods=self._methods) + return HttpTriggerMetadata( + path=path, methods=self._methods, trigger_mode=self._trigger_mode() + ) def remove_from_app( self, app: "FastAPI", global_prefix: Optional[str] = None @@ -537,6 +543,36 @@ class HttpTrigger(Trigger): # TODO, remove with path and methods del app_router.routes[i] + def _trigger_mode(self) -> str: + if ( + self._req_body + and isinstance(self._req_body, type) + and issubclass(self._req_body, CommonLLMHttpRequestBody) + ): + return "chat" + return "command" + + async def map(self, input_data: Any) -> Any: + """Map the input data. + + Do some transformation for the input data. + + Args: + input_data (Any): The input data from caller. + + Returns: + Any: The mapped data. 
+ """ + if not self._req_body or not input_data: + return await super().map(input_data) + if ( + isinstance(self._req_body, type) + and issubclass(self._req_body, BaseModel) + and isinstance(input_data, dict) + ): + return self._req_body(**input_data) + return await super().map(input_data) + def _create_route_func(self): from inspect import Parameter, Signature from typing import get_type_hints diff --git a/dbgpt/core/awel/util/parameter_util.py b/dbgpt/core/awel/util/parameter_util.py index defd99a3b..a492169c5 100644 --- a/dbgpt/core/awel/util/parameter_util.py +++ b/dbgpt/core/awel/util/parameter_util.py @@ -2,20 +2,60 @@ import inspect from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, List +from typing import Any, Callable, Dict, List, Literal, Optional from dbgpt._private.pydantic import BaseModel, Field, model_validator +from dbgpt.component import SystemApp from dbgpt.core.interface.serialization import Serializable _DEFAULT_DYNAMIC_REGISTRY = {} +class RefreshOptionDependency(BaseModel): + """The refresh dependency.""" + + name: str = Field(..., description="The name of the refresh dependency") + value: Optional[Any] = Field( + None, description="The value of the refresh dependency" + ) + has_value: bool = Field( + False, description="Whether the refresh dependency has value" + ) + + +class RefreshOptionRequest(BaseModel): + """The refresh option request.""" + + name: str = Field(..., description="The name of parameter to refresh") + depends: Optional[List[RefreshOptionDependency]] = Field( + None, description="The depends of the refresh config" + ) + variables_key: Optional[str] = Field( + None, description="The variables key to refresh" + ) + variables_scope: Optional[str] = Field( + None, description="The variables scope to refresh" + ) + variables_scope_key: Optional[str] = Field( + None, description="The variables scope key to refresh" + ) + variables_sys_code: Optional[str] = Field( + None, description="The system code to refresh" + ) + variables_user_name: Optional[str] = Field( + None, description="The user name to refresh" + ) + + class OptionValue(Serializable, BaseModel): """The option value of the parameter.""" label: str = Field(..., description="The label of the option") name: str = Field(..., description="The name of the option") value: Any = Field(..., description="The value of the option") + children: Optional[List["OptionValue"]] = Field( + None, description="The children of the option" + ) def to_dict(self) -> Dict: """Convert current metadata to json dict.""" @@ -25,24 +65,80 @@ class OptionValue(Serializable, BaseModel): class BaseDynamicOptions(Serializable, BaseModel, ABC): """The base dynamic options.""" - @abstractmethod + def support_async( + self, + system_app: Optional[SystemApp] = None, + request: Optional[RefreshOptionRequest] = None, + ) -> bool: + """Whether the dynamic options support async. + + Args: + system_app (Optional[SystemApp]): The system app + request (Optional[RefreshOptionRequest]): The refresh request + + Returns: + bool: Whether the dynamic options support async + """ + return False + def option_values(self) -> List[OptionValue]: """Return the option values of the parameter.""" + return self.refresh(None) + + @abstractmethod + def refresh( + self, + request: Optional[RefreshOptionRequest] = None, + trigger: Literal["default", "http"] = "default", + system_app: Optional[SystemApp] = None, + ) -> List[OptionValue]: + """Refresh the dynamic options. 
+ + Args: + request (Optional[RefreshOptionRequest]): The refresh request + trigger (Literal["default", "http"]): The trigger type, how to trigger + the refresh + system_app (Optional[SystemApp]): The system app + """ + + async def async_refresh( + self, + request: Optional[RefreshOptionRequest] = None, + trigger: Literal["default", "http"] = "default", + system_app: Optional[SystemApp] = None, + ) -> List[OptionValue]: + """Refresh the dynamic options async. + + Args: + request (Optional[RefreshOptionRequest]): The refresh request + trigger (Literal["default", "http"]): The trigger type, how to trigger + the refresh + system_app (Optional[SystemApp]): The system app + """ + raise NotImplementedError("The dynamic options does not support async.") class FunctionDynamicOptions(BaseDynamicOptions): """The function dynamic options.""" - func: Callable[[], List[OptionValue]] = Field( + func: Callable[..., List[OptionValue]] = Field( ..., description="The function to generate the dynamic options" ) func_id: str = Field( ..., description="The unique id of the function to generate the dynamic options" ) - def option_values(self) -> List[OptionValue]: - """Return the option values of the parameter.""" - return self.func() + def refresh( + self, + request: Optional[RefreshOptionRequest] = None, + trigger: Literal["default", "http"] = "default", + system_app: Optional[SystemApp] = None, + ) -> List[OptionValue]: + """Refresh the dynamic options.""" + if not request or not request.depends: + return self.func() + kwargs = {dep.name: dep.value for dep in request.depends if dep.has_value} + return self.func(**kwargs) @model_validator(mode="before") @classmethod @@ -65,6 +161,109 @@ class FunctionDynamicOptions(BaseDynamicOptions): return {"func_id": self.func_id} +class VariablesDynamicOptions(BaseDynamicOptions): + """The variables dynamic options.""" + + def support_async( + self, + system_app: Optional[SystemApp] = None, + request: Optional[RefreshOptionRequest] = None, + ) -> bool: + """Whether the dynamic options support async.""" + if not system_app or not request or not request.variables_key: + return False + + from ...interface.variables import BuiltinVariablesProvider + + provider: BuiltinVariablesProvider = system_app.get_component( + request.variables_key, + component_type=BuiltinVariablesProvider, + default_component=None, + ) + if not provider: + return False + return provider.support_async() + + def refresh( + self, + request: Optional[RefreshOptionRequest] = None, + trigger: Literal["default", "http"] = "default", + system_app: Optional[SystemApp] = None, + ) -> List[OptionValue]: + """Refresh the dynamic options.""" + if ( + trigger == "default" + or not request + or not request.variables_key + or not request.variables_scope + ): + # Only refresh when trigger is http and request is not None + return [] + if not system_app: + raise ValueError("The system app is required when refresh the variables.") + from ...interface.variables import VariablesProvider + + vp: VariablesProvider = VariablesProvider.get_instance(system_app) + variables = vp.get_variables( + key=request.variables_key, + scope=request.variables_scope, + scope_key=request.variables_scope_key, + sys_code=request.variables_sys_code, + user_name=request.variables_user_name, + ) + options = [] + for var in variables: + options.append( + OptionValue( + label=var.label, + name=var.name, + value=var.value, + ) + ) + return options + + async def async_refresh( + self, + request: Optional[RefreshOptionRequest] = None, + trigger: 
Literal["default", "http"] = "default", + system_app: Optional[SystemApp] = None, + ) -> List[OptionValue]: + """Refresh the dynamic options async.""" + if ( + trigger == "default" + or not request + or not request.variables_key + or not request.variables_scope + ): + return [] + if not system_app: + raise ValueError("The system app is required when refresh the variables.") + from ...interface.variables import VariablesProvider + + vp: VariablesProvider = VariablesProvider.get_instance(system_app) + variables = await vp.async_get_variables( + key=request.variables_key, + scope=request.variables_scope, + scope_key=request.variables_scope_key, + sys_code=request.variables_sys_code, + user_name=request.variables_user_name, + ) + options = [] + for var in variables: + options.append( + OptionValue( + label=var.label, + name=var.name, + value=var.value, + ) + ) + return options + + def to_dict(self) -> Dict: + """Convert current metadata to json dict.""" + return {"key": self.key} + + def _generate_unique_id(func: Callable) -> str: if func.__name__ == "": func_id = f"lambda_{inspect.getfile(func)}_{inspect.getsourcelines(func)}" diff --git a/dbgpt/core/interface/file.py b/dbgpt/core/interface/file.py new file mode 100644 index 000000000..ea1ddb2f3 --- /dev/null +++ b/dbgpt/core/interface/file.py @@ -0,0 +1,834 @@ +"""File storage interface.""" + +import dataclasses +import hashlib +import io +import logging +import os +import uuid +from abc import ABC, abstractmethod +from io import BytesIO +from typing import Any, BinaryIO, Dict, List, Optional, Tuple +from urllib.parse import parse_qs, urlencode, urlparse + +import requests + +from dbgpt.component import BaseComponent, ComponentType, SystemApp +from dbgpt.util.tracer import root_tracer, trace + +from .storage import ( + InMemoryStorage, + QuerySpec, + ResourceIdentifier, + StorageError, + StorageInterface, + StorageItem, +) + +logger = logging.getLogger(__name__) +_SCHEMA = "dbgpt-fs" + + +@dataclasses.dataclass +class FileMetadataIdentifier(ResourceIdentifier): + """File metadata identifier.""" + + file_id: str + bucket: str + + def to_dict(self) -> Dict: + """Convert the identifier to a dictionary.""" + return {"file_id": self.file_id, "bucket": self.bucket} + + @property + def str_identifier(self) -> str: + """Get the string identifier. 
+ + Returns: + str: The string identifier + """ + return f"{self.bucket}/{self.file_id}" + + +@dataclasses.dataclass +class FileMetadata(StorageItem): + """File metadata for storage.""" + + file_id: str + bucket: str + file_name: str + file_size: int + storage_type: str + storage_path: str + uri: str + custom_metadata: Dict[str, Any] + file_hash: str + user_name: Optional[str] = None + sys_code: Optional[str] = None + _identifier: FileMetadataIdentifier = dataclasses.field(init=False) + + def __post_init__(self): + """Post init method.""" + self._identifier = FileMetadataIdentifier( + file_id=self.file_id, bucket=self.bucket + ) + custom_metadata = self.custom_metadata or {} + if not self.user_name: + self.user_name = custom_metadata.get("user_name") + if not self.sys_code: + self.sys_code = custom_metadata.get("sys_code") + + @property + def identifier(self) -> ResourceIdentifier: + """Get the resource identifier.""" + return self._identifier + + def merge(self, other: "StorageItem") -> None: + """Merge the metadata with another item.""" + if not isinstance(other, FileMetadata): + raise StorageError("Cannot merge different types of items") + self._from_object(other) + + def to_dict(self) -> Dict: + """Convert the metadata to a dictionary.""" + return { + "file_id": self.file_id, + "bucket": self.bucket, + "file_name": self.file_name, + "file_size": self.file_size, + "storage_type": self.storage_type, + "storage_path": self.storage_path, + "uri": self.uri, + "custom_metadata": self.custom_metadata, + "file_hash": self.file_hash, + } + + def _from_object(self, obj: "FileMetadata") -> None: + self.file_id = obj.file_id + self.bucket = obj.bucket + self.file_name = obj.file_name + self.file_size = obj.file_size + self.storage_type = obj.storage_type + self.storage_path = obj.storage_path + self.uri = obj.uri + self.custom_metadata = obj.custom_metadata + self.file_hash = obj.file_hash + self._identifier = obj._identifier + + +class FileStorageURI: + """File storage URI.""" + + def __init__( + self, + storage_type: str, + bucket: str, + file_id: str, + version: Optional[str] = None, + custom_params: Optional[Dict[str, Any]] = None, + ): + """Initialize the file storage URI.""" + self.scheme = _SCHEMA + self.storage_type = storage_type + self.bucket = bucket + self.file_id = file_id + self.version = version + self.custom_params = custom_params or {} + + @classmethod + def is_local_file(cls, uri: str) -> bool: + """Check if the URI is local.""" + parsed = urlparse(uri) + if not parsed.scheme or parsed.scheme == "file": + return True + return False + + @classmethod + def parse(cls, uri: str) -> "FileStorageURI": + """Parse the URI string.""" + parsed = urlparse(uri) + if parsed.scheme != _SCHEMA: + raise ValueError(f"Invalid URI scheme. Must be '{_SCHEMA}'") + path_parts = parsed.path.strip("/").split("/") + if len(path_parts) < 2: + raise ValueError("Invalid URI path. 
Must contain bucket and file ID") + storage_type = parsed.netloc + bucket = path_parts[0] + file_id = path_parts[1] + version = path_parts[2] if len(path_parts) > 2 else None + custom_params = parse_qs(parsed.query) + return cls(storage_type, bucket, file_id, version, custom_params) + + def __str__(self) -> str: + """Get the string representation of the URI.""" + base_uri = f"{self.scheme}://{self.storage_type}/{self.bucket}/{self.file_id}" + if self.version: + base_uri += f"/{self.version}" + if self.custom_params: + query_string = urlencode(self.custom_params, doseq=True) + base_uri += f"?{query_string}" + return base_uri + + +class StorageBackend(ABC): + """Storage backend interface.""" + + storage_type: str = "__base__" + + @abstractmethod + def save(self, bucket: str, file_id: str, file_data: BinaryIO) -> str: + """Save the file data to the storage backend. + + Args: + bucket (str): The bucket name + file_id (str): The file ID + file_data (BinaryIO): The file data + + Returns: + str: The storage path + """ + + @abstractmethod + def load(self, fm: FileMetadata) -> BinaryIO: + """Load the file data from the storage backend. + + Args: + fm (FileMetadata): The file metadata + + Returns: + BinaryIO: The file data + """ + + @abstractmethod + def delete(self, fm: FileMetadata) -> bool: + """Delete the file data from the storage backend. + + Args: + fm (FileMetadata): The file metadata + + Returns: + bool: True if the file was deleted, False otherwise + """ + + @property + @abstractmethod + def save_chunk_size(self) -> int: + """Get the save chunk size. + + Returns: + int: The save chunk size + """ + + +class LocalFileStorage(StorageBackend): + """Local file storage backend.""" + + storage_type: str = "local" + + def __init__(self, base_path: str, save_chunk_size: int = 1024 * 1024): + """Initialize the local file storage backend.""" + self.base_path = base_path + self._save_chunk_size = save_chunk_size + os.makedirs(self.base_path, exist_ok=True) + + @property + def save_chunk_size(self) -> int: + """Get the save chunk size.""" + return self._save_chunk_size + + def save(self, bucket: str, file_id: str, file_data: BinaryIO) -> str: + """Save the file data to the local storage backend.""" + bucket_path = os.path.join(self.base_path, bucket) + os.makedirs(bucket_path, exist_ok=True) + file_path = os.path.join(bucket_path, file_id) + with open(file_path, "wb") as f: + while True: + chunk = file_data.read(self.save_chunk_size) + if not chunk: + break + f.write(chunk) + return file_path + + def load(self, fm: FileMetadata) -> BinaryIO: + """Load the file data from the local storage backend.""" + bucket_path = os.path.join(self.base_path, fm.bucket) + file_path = os.path.join(bucket_path, fm.file_id) + return open(file_path, "rb") # noqa: SIM115 + + def delete(self, fm: FileMetadata) -> bool: + """Delete the file data from the local storage backend.""" + bucket_path = os.path.join(self.base_path, fm.bucket) + file_path = os.path.join(bucket_path, fm.file_id) + if os.path.exists(file_path): + os.remove(file_path) + return True + return False + + +class FileStorageSystem: + """File storage system.""" + + def __init__( + self, + storage_backends: Dict[str, StorageBackend], + metadata_storage: Optional[StorageInterface[FileMetadata, Any]] = None, + check_hash: bool = True, + ): + """Initialize the file storage system.""" + metadata_storage = metadata_storage or InMemoryStorage() + self.storage_backends = storage_backends + self.metadata_storage = metadata_storage + self.check_hash = check_hash + 
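# Use the smallest chunk size among the configured backends so the
+        # chunked reads used below for saving and hashing never exceed any
+        # backend's preferred chunk size.
+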
self._save_chunk_size = min( + backend.save_chunk_size for backend in storage_backends.values() + ) + + def _calculate_file_hash(self, file_data: BinaryIO) -> str: + """Calculate the MD5 hash of the file data.""" + if not self.check_hash: + return "-1" + hasher = hashlib.md5() + file_data.seek(0) + while chunk := file_data.read(self._save_chunk_size): + hasher.update(chunk) + file_data.seek(0) + return hasher.hexdigest() + + @trace("file_storage_system.save_file") + def save_file( + self, + bucket: str, + file_name: str, + file_data: BinaryIO, + storage_type: str, + custom_metadata: Optional[Dict[str, Any]] = None, + ) -> str: + """Save the file data to the storage backend.""" + file_id = str(uuid.uuid4()) + backend = self.storage_backends.get(storage_type) + if not backend: + raise ValueError(f"Unsupported storage type: {storage_type}") + + with root_tracer.start_span( + "file_storage_system.save_file.backend_save", + metadata={ + "bucket": bucket, + "file_id": file_id, + "file_name": file_name, + "storage_type": storage_type, + }, + ): + storage_path = backend.save(bucket, file_id, file_data) + file_data.seek(0, 2) # Move to the end of the file + file_size = file_data.tell() # Get the file size + file_data.seek(0) # Reset file pointer + + # filter None value + custom_metadata = ( + {k: v for k, v in custom_metadata.items() if v is not None} + if custom_metadata + else {} + ) + + with root_tracer.start_span( + "file_storage_system.save_file.calculate_hash", + ): + file_hash = self._calculate_file_hash(file_data) + uri = FileStorageURI( + storage_type, bucket, file_id, custom_params=custom_metadata + ) + + metadata = FileMetadata( + file_id=file_id, + bucket=bucket, + file_name=file_name, + file_size=file_size, + storage_type=storage_type, + storage_path=storage_path, + uri=str(uri), + custom_metadata=custom_metadata, + file_hash=file_hash, + ) + + self.metadata_storage.save(metadata) + return str(uri) + + @trace("file_storage_system.get_file") + def get_file(self, uri: str) -> Tuple[BinaryIO, FileMetadata]: + """Get the file data from the storage backend.""" + if FileStorageURI.is_local_file(uri): + local_file_name = uri.split("/")[-1] + if not os.path.exists(uri): + raise FileNotFoundError(f"File not found: {uri}") + + dummy_metadata = FileMetadata( + file_id=local_file_name, + bucket="dummy_bucket", + file_name=local_file_name, + file_size=-1, + storage_type="local", + storage_path=uri, + uri=uri, + custom_metadata={}, + file_hash="", + ) + logger.info(f"Reading local file: {uri}") + return open(uri, "rb"), dummy_metadata # noqa: SIM115 + + parsed_uri = FileStorageURI.parse(uri) + metadata = self.metadata_storage.load( + FileMetadataIdentifier( + file_id=parsed_uri.file_id, bucket=parsed_uri.bucket + ), + FileMetadata, + ) + if not metadata: + raise FileNotFoundError(f"No metadata found for URI: {uri}") + + backend = self.storage_backends.get(metadata.storage_type) + if not backend: + raise ValueError(f"Unsupported storage type: {metadata.storage_type}") + + with root_tracer.start_span( + "file_storage_system.get_file.backend_load", + metadata={ + "bucket": metadata.bucket, + "file_id": metadata.file_id, + "file_name": metadata.file_name, + "storage_type": metadata.storage_type, + }, + ): + file_data = backend.load(metadata) + + with root_tracer.start_span( + "file_storage_system.get_file.verify_hash", + ): + calculated_hash = self._calculate_file_hash(file_data) + if calculated_hash != "-1" and calculated_hash != metadata.file_hash: + raise ValueError("File integrity check failed. 
Hash mismatch.") + + return file_data, metadata + + def get_file_metadata(self, bucket: str, file_id: str) -> Optional[FileMetadata]: + """Get the file metadata. + + Args: + bucket (str): The bucket name + file_id (str): The file ID + + Returns: + Optional[FileMetadata]: The file metadata + """ + fid = FileMetadataIdentifier(file_id=file_id, bucket=bucket) + return self.metadata_storage.load(fid, FileMetadata) + + def delete_file(self, uri: str) -> bool: + """Delete the file data from the storage backend. + + Args: + uri (str): The file URI + + Returns: + bool: True if the file was deleted, False otherwise + """ + parsed_uri = FileStorageURI.parse(uri) + fid = FileMetadataIdentifier( + file_id=parsed_uri.file_id, bucket=parsed_uri.bucket + ) + metadata = self.metadata_storage.load(fid, FileMetadata) + if not metadata: + return False + + backend = self.storage_backends.get(metadata.storage_type) + if not backend: + raise ValueError(f"Unsupported storage type: {metadata.storage_type}") + + if backend.delete(metadata): + try: + self.metadata_storage.delete(fid) + return True + except Exception: + # If the metadata deletion fails, log the error and return False + return False + return False + + def list_files( + self, bucket: str, filters: Optional[Dict[str, Any]] = None + ) -> List[FileMetadata]: + """List the files in the bucket.""" + filters = filters or {} + filters["bucket"] = bucket + return self.metadata_storage.query(QuerySpec(conditions=filters), FileMetadata) + + +class FileStorageClient(BaseComponent): + """File storage client component.""" + + name = ComponentType.FILE_STORAGE_CLIENT.value + + def __init__( + self, + system_app: Optional[SystemApp] = None, + storage_system: Optional[FileStorageSystem] = None, + ): + """Initialize the file storage client.""" + super().__init__(system_app=system_app) + if not storage_system: + from pathlib import Path + + base_path = Path.home() / ".cache" / "dbgpt" / "files" + storage_system = FileStorageSystem( + { + LocalFileStorage.storage_type: LocalFileStorage( + base_path=str(base_path) + ) + } + ) + + self.system_app = system_app + self._storage_system = storage_system + + def init_app(self, system_app: SystemApp): + """Initialize the application.""" + self.system_app = system_app + + @property + def storage_system(self) -> FileStorageSystem: + """Get the file storage system.""" + if not self._storage_system: + raise ValueError("File storage system not initialized") + return self._storage_system + + def upload_file( + self, + bucket: str, + file_path: str, + storage_type: str, + custom_metadata: Optional[Dict[str, Any]] = None, + ) -> str: + """Upload a file to the storage system. + + Args: + bucket (str): The bucket name + file_path (str): The file path + storage_type (str): The storage type + custom_metadata (Dict[str, Any], optional): Custom metadata. Defaults to + None. + + Returns: + str: The file URI + """ + with open(file_path, "rb") as file: + return self.save_file( + bucket, os.path.basename(file_path), file, storage_type, custom_metadata + ) + + def save_file( + self, + bucket: str, + file_name: str, + file_data: BinaryIO, + storage_type: str, + custom_metadata: Optional[Dict[str, Any]] = None, + ) -> str: + """Save the file data to the storage system. + + Args: + bucket (str): The bucket name + file_name (str): The file name + file_data (BinaryIO): The file data + storage_type (str): The storage type + custom_metadata (Dict[str, Any], optional): Custom metadata. Defaults to + None. 
+    def save_file(
+        self,
+        bucket: str,
+        file_name: str,
+        file_data: BinaryIO,
+        storage_type: str,
+        custom_metadata: Optional[Dict[str, Any]] = None,
+    ) -> str:
+        """Save the file data to the storage system.
+
+        Args:
+            bucket (str): The bucket name
+            file_name (str): The file name
+            file_data (BinaryIO): The file data
+            storage_type (str): The storage type
+            custom_metadata (Dict[str, Any], optional): Custom metadata. Defaults to
+                None.
+
+        Returns:
+            str: The file URI
+        """
+        return self.storage_system.save_file(
+            bucket, file_name, file_data, storage_type, custom_metadata
+        )
+
+    def download_file(self, uri: str, destination_path: str) -> None:
+        """Download a file from the storage system.
+
+        Args:
+            uri (str): The file URI
+            destination_path (str): The destination path
+
+        Raises:
+            FileNotFoundError: If the file is not found
+        """
+        file_data, _ = self.storage_system.get_file(uri)
+        with open(destination_path, "wb") as f:
+            f.write(file_data.read())
+
+    def get_file(self, uri: str) -> Tuple[BinaryIO, FileMetadata]:
+        """Get the file data from the storage system.
+
+        Args:
+            uri (str): The file URI
+
+        Returns:
+            Tuple[BinaryIO, FileMetadata]: The file data and metadata
+        """
+        return self.storage_system.get_file(uri)
+
+    def get_file_by_id(
+        self, bucket: str, file_id: str
+    ) -> Tuple[BinaryIO, FileMetadata]:
+        """Get the file data from the storage system by ID.
+
+        Args:
+            bucket (str): The bucket name
+            file_id (str): The file ID
+
+        Returns:
+            Tuple[BinaryIO, FileMetadata]: The file data and metadata
+        """
+        metadata = self.storage_system.get_file_metadata(bucket, file_id)
+        if not metadata:
+            raise FileNotFoundError(f"File {file_id} not found in bucket {bucket}")
+        return self.get_file(metadata.uri)
+
+    def delete_file(self, uri: str) -> bool:
+        """Delete the file data from the storage system.
+
+        Args:
+            uri (str): The file URI
+
+        Returns:
+            bool: True if the file was deleted, False otherwise
+        """
+        return self.storage_system.delete_file(uri)
+
+    def delete_file_by_id(self, bucket: str, file_id: str) -> bool:
+        """Delete the file data from the storage system by ID.
+
+        Args:
+            bucket (str): The bucket name
+            file_id (str): The file ID
+
+        Returns:
+            bool: True if the file was deleted, False otherwise
+        """
+        metadata = self.storage_system.get_file_metadata(bucket, file_id)
+        if not metadata:
+            raise FileNotFoundError(f"File {file_id} not found in bucket {bucket}")
+        return self.delete_file(metadata.uri)
+
+    def list_files(
+        self, bucket: str, filters: Optional[Dict[str, Any]] = None
+    ) -> List[FileMetadata]:
+        """List the files in the bucket.
+
+        Args:
+            bucket (str): The bucket name
+            filters (Dict[str, Any], optional): Filters. Defaults to None.
+ + Returns: + List[FileMetadata]: The list of file metadata + """ + return self.storage_system.list_files(bucket, filters) + + +class SimpleDistributedStorage(StorageBackend): + """Simple distributed storage backend.""" + + storage_type: str = "distributed" + + def __init__( + self, + node_address: str, + local_storage_path: str, + save_chunk_size: int = 1024 * 1024, + transfer_chunk_size: int = 1024 * 1024, + transfer_timeout: int = 360, + api_prefix: str = "/api/v2/serve/file/files", + ): + """Initialize the simple distributed storage backend.""" + self.node_address = node_address + self.local_storage_path = local_storage_path + os.makedirs(self.local_storage_path, exist_ok=True) + self._save_chunk_size = save_chunk_size + self._transfer_chunk_size = transfer_chunk_size + self._transfer_timeout = transfer_timeout + self._api_prefix = api_prefix + + @property + def save_chunk_size(self) -> int: + """Get the save chunk size.""" + return self._save_chunk_size + + def _get_file_path(self, bucket: str, file_id: str, node_address: str) -> str: + node_id = hashlib.md5(node_address.encode()).hexdigest() + return os.path.join(self.local_storage_path, bucket, f"{file_id}_{node_id}") + + def _parse_node_address(self, fm: FileMetadata) -> str: + storage_path = fm.storage_path + if not storage_path.startswith("distributed://"): + raise ValueError("Invalid storage path") + return storage_path.split("//")[1].split("/")[0] + + def save(self, bucket: str, file_id: str, file_data: BinaryIO) -> str: + """Save the file data to the distributed storage backend. + + Just save the file locally. + + Args: + bucket (str): The bucket name + file_id (str): The file ID + file_data (BinaryIO): The file data + + Returns: + str: The storage path + """ + file_path = self._get_file_path(bucket, file_id, self.node_address) + os.makedirs(os.path.dirname(file_path), exist_ok=True) + with open(file_path, "wb") as f: + while True: + chunk = file_data.read(self.save_chunk_size) + if not chunk: + break + f.write(chunk) + + return f"distributed://{self.node_address}/{bucket}/{file_id}" + + def load(self, fm: FileMetadata) -> BinaryIO: + """Load the file data from the distributed storage backend. + + If the file is stored on the local node, load it from the local storage. + + Args: + fm (FileMetadata): The file metadata + + Returns: + BinaryIO: The file data + """ + file_id = fm.file_id + bucket = fm.bucket + node_address = self._parse_node_address(fm) + file_path = self._get_file_path(bucket, file_id, node_address) + + # TODO: check if the file is cached in local storage + if node_address == self.node_address: + if os.path.exists(file_path): + return open(file_path, "rb") # noqa: SIM115 + else: + raise FileNotFoundError(f"File {file_id} not found on the local node") + else: + response = requests.get( + f"http://{node_address}{self._api_prefix}/{bucket}/{file_id}", + timeout=self._transfer_timeout, + stream=True, + ) + response.raise_for_status() + # TODO: cache the file in local storage + return StreamedBytesIO( + response.iter_content(chunk_size=self._transfer_chunk_size) + ) + + def delete(self, fm: FileMetadata) -> bool: + """Delete the file data from the distributed storage backend. + + If the file is stored on the local node, delete it from the local storage. + If the file is stored on a remote node, send a delete request to the remote + node. 
+ + Args: + fm (FileMetadata): The file metadata + + Returns: + bool: True if the file was deleted, False otherwise + """ + file_id = fm.file_id + bucket = fm.bucket + node_address = self._parse_node_address(fm) + file_path = self._get_file_path(bucket, file_id, node_address) + if node_address == self.node_address: + if os.path.exists(file_path): + os.remove(file_path) + return True + return False + else: + try: + response = requests.delete( + f"http://{node_address}{self._api_prefix}/{bucket}/{file_id}", + timeout=self._transfer_timeout, + ) + response.raise_for_status() + return True + except Exception: + return False + + +class StreamedBytesIO(io.BytesIO): + """A BytesIO subclass that can be used with streaming responses. + + Adapted from: https://gist.github.com/obskyr/b9d4b4223e7eaf4eedcd9defabb34f13 + """ + + def __init__(self, request_iterator): + """Initialize the StreamedBytesIO instance.""" + super().__init__() + self._bytes = BytesIO() + self._iterator = request_iterator + + def _load_all(self): + self._bytes.seek(0, io.SEEK_END) + for chunk in self._iterator: + self._bytes.write(chunk) + + def _load_until(self, goal_position): + current_position = self._bytes.seek(0, io.SEEK_END) + while current_position < goal_position: + try: + current_position += self._bytes.write(next(self._iterator)) + except StopIteration: + break + + def tell(self) -> int: + """Get the current position.""" + return self._bytes.tell() + + def read(self, size: Optional[int] = None) -> bytes: + """Read the data from the stream. + + Args: + size (Optional[int], optional): The number of bytes to read. Defaults to + None. + + Returns: + bytes: The read data + """ + left_off_at = self._bytes.tell() + if size is None: + self._load_all() + else: + goal_position = left_off_at + size + self._load_until(goal_position) + + self._bytes.seek(left_off_at) + return self._bytes.read(size) + + def seek(self, position: int, whence: int = io.SEEK_SET): + """Seek to a position in the stream. + + Args: + position (int): The position + whence (int, optional): The reference point. 
Defaults to io.SEEK_SET.
+
+        Raises:
+            ValueError: If the reference point is invalid
+        """
+        if whence == io.SEEK_END:
+            # Relative-to-end seeks need the full stream loaded first
+            self._load_all()
+        self._bytes.seek(position, whence)
+
+    def __enter__(self):
+        """Enter the context manager."""
+        return self
+
+    def __exit__(self, ext_type, value, tb):
+        """Exit the context manager."""
+        self._bytes.close()
diff --git a/dbgpt/core/interface/operators/llm_operator.py b/dbgpt/core/interface/operators/llm_operator.py
index 53e34ffe5..45863d0a9 100644
--- a/dbgpt/core/interface/operators/llm_operator.py
+++ b/dbgpt/core/interface/operators/llm_operator.py
@@ -24,6 +24,7 @@ from dbgpt.core.awel.flow import (
     OperatorType,
     Parameter,
     ViewMetadata,
+    ui,
 )
 from dbgpt.core.interface.llm import (
     LLMClient,
@@ -69,6 +70,10 @@ class RequestBuilderOperator(MapOperator[RequestInput, ModelRequest]):
             optional=True,
             default=None,
             description=_("The temperature of the model request."),
+            ui=ui.UISlider(
+                show_input=True,
+                attr=ui.UISlider.UIAttribute(min=0.0, max=2.0, step=0.1),
+            ),
         ),
         Parameter.build_from(
             _("Max New Tokens"),
diff --git a/dbgpt/core/interface/operators/prompt_operator.py b/dbgpt/core/interface/operators/prompt_operator.py
index c3765aa67..7d97230ac 100644
--- a/dbgpt/core/interface/operators/prompt_operator.py
+++ b/dbgpt/core/interface/operators/prompt_operator.py
@@ -1,4 +1,5 @@
 """The prompt operator."""
+
 from abc import ABC
 from typing import Any, Dict, List, Optional, Union
@@ -18,6 +19,7 @@ from dbgpt.core.awel.flow import (
     ResourceCategory,
     ViewMetadata,
     register_resource,
+    ui,
 )
 from dbgpt.core.interface.message import BaseMessage
 from dbgpt.core.interface.operators.llm_operator import BaseLLM
@@ -48,6 +50,7 @@ from dbgpt.util.i18n_utils import _
             optional=True,
             default="You are a helpful AI Assistant.",
             description=_("The system message."),
+            ui=ui.DefaultUITextArea(),
         ),
         Parameter.build_from(
             label=_("Message placeholder"),
@@ -65,6 +68,7 @@ from dbgpt.util.i18n_utils import _
             default="{user_input}",
             placeholder="{user_input}",
             description=_("The human message."),
+            ui=ui.DefaultUITextArea(),
         ),
     ],
 )
diff --git a/dbgpt/core/interface/storage.py b/dbgpt/core/interface/storage.py
index 2a61746ec..4bf152ab8 100644
--- a/dbgpt/core/interface/storage.py
+++ b/dbgpt/core/interface/storage.py
@@ -3,13 +3,14 @@
 from abc import ABC, abstractmethod
 from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, cast
-from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
 from dbgpt.core.interface.serialization import Serializable, Serializer
 from dbgpt.util.annotations import PublicAPI
 from dbgpt.util.i18n_utils import _
 from dbgpt.util.pagination_utils import PaginationResult
 from dbgpt.util.serialization.json_serialization import JsonSerializer
+from ..awel.flow import Parameter, ResourceCategory, register_resource
+
diff --git a/dbgpt/core/interface/tests/test_file.py b/dbgpt/core/interface/tests/test_file.py
new file mode 100644
index 000000000..f6e462944
--- /dev/null
+++ b/dbgpt/core/interface/tests/test_file.py
@@ -0,0 +1,506 @@
+import hashlib
+import io
+import os
+from unittest import mock
+
+import pytest
+
+from ..file import (
+    FileMetadata,
+    FileMetadataIdentifier,
+    FileStorageClient,
+    FileStorageSystem,
+    InMemoryStorage,
+    LocalFileStorage,
+    SimpleDistributedStorage,
+)
+
+
+@pytest.fixture
+def temp_test_file_dir(tmpdir):
+    return str(tmpdir)
+
+
+@pytest.fixture
+def temp_storage_path(tmpdir):
+    return
str(tmpdir) + + +@pytest.fixture +def local_storage_backend(temp_storage_path): + return LocalFileStorage(temp_storage_path) + + +@pytest.fixture +def distributed_storage_backend(temp_storage_path): + node_address = "127.0.0.1:8000" + return SimpleDistributedStorage(node_address, temp_storage_path) + + +@pytest.fixture +def file_storage_system(local_storage_backend): + backends = {"local": local_storage_backend} + metadata_storage = InMemoryStorage() + return FileStorageSystem(backends, metadata_storage) + + +@pytest.fixture +def file_storage_client(file_storage_system): + return FileStorageClient(storage_system=file_storage_system) + + +@pytest.fixture +def sample_file_path(temp_test_file_dir): + file_path = os.path.join(temp_test_file_dir, "sample.txt") + with open(file_path, "wb") as f: + f.write(b"Sample file content") + return file_path + + +@pytest.fixture +def sample_file_data(): + return io.BytesIO(b"Sample file content for distributed storage") + + +def test_save_file(file_storage_client, sample_file_path): + bucket = "test-bucket" + uri = file_storage_client.upload_file( + bucket=bucket, file_path=sample_file_path, storage_type="local" + ) + assert uri.startswith("dbgpt-fs://local/test-bucket/") + assert os.path.exists(sample_file_path) + + +def test_get_file(file_storage_client, sample_file_path): + bucket = "test-bucket" + uri = file_storage_client.upload_file( + bucket=bucket, file_path=sample_file_path, storage_type="local" + ) + file_data, metadata = file_storage_client.storage_system.get_file(uri) + assert file_data.read() == b"Sample file content" + assert metadata.file_name == "sample.txt" + assert metadata.bucket == bucket + + +def test_delete_file(file_storage_client, sample_file_path): + bucket = "test-bucket" + uri = file_storage_client.upload_file( + bucket=bucket, file_path=sample_file_path, storage_type="local" + ) + assert len(file_storage_client.list_files(bucket=bucket)) == 1 + result = file_storage_client.delete_file(uri) + assert result is True + assert len(file_storage_client.list_files(bucket=bucket)) == 0 + + +def test_list_files(file_storage_client, sample_file_path): + bucket = "test-bucket" + uri1 = file_storage_client.upload_file( + bucket=bucket, file_path=sample_file_path, storage_type="local" + ) + files = file_storage_client.list_files(bucket=bucket) + assert len(files) == 1 + + +def test_save_file_unsupported_storage(file_storage_system, sample_file_path): + bucket = "test-bucket" + with pytest.raises(ValueError): + file_storage_system.save_file( + bucket=bucket, + file_name="unsupported.txt", + file_data=io.BytesIO(b"Unsupported storage"), + storage_type="unsupported", + ) + + +def test_get_file_not_found(file_storage_system): + with pytest.raises(FileNotFoundError): + file_storage_system.get_file("dbgpt-fs://local/test-bucket/nonexistent") + + +def test_delete_file_not_found(file_storage_system): + result = file_storage_system.delete_file("dbgpt-fs://local/test-bucket/nonexistent") + assert result is False + + +def test_metadata_management(file_storage_system): + bucket = "test-bucket" + file_id = "test_file" + metadata = file_storage_system.metadata_storage.save( + FileMetadata( + file_id=file_id, + bucket=bucket, + file_name="test.txt", + file_size=100, + storage_type="local", + storage_path="/path/to/test.txt", + uri="dbgpt-fs://local/test-bucket/test_file", + custom_metadata={"key": "value"}, + file_hash="hash", + ) + ) + + loaded_metadata = file_storage_system.metadata_storage.load( + FileMetadataIdentifier(file_id=file_id, bucket=bucket), 
FileMetadata + ) + assert loaded_metadata.file_name == "test.txt" + assert loaded_metadata.custom_metadata["key"] == "value" + assert loaded_metadata.bucket == bucket + + +def test_concurrent_save_and_delete(file_storage_client, sample_file_path): + bucket = "test-bucket" + + # Simulate concurrent file save and delete operations + def save_file(): + return file_storage_client.upload_file( + bucket=bucket, file_path=sample_file_path, storage_type="local" + ) + + def delete_file(uri): + return file_storage_client.delete_file(uri) + + uri = save_file() + + # Simulate concurrent operations + save_file() + delete_file(uri) + assert len(file_storage_client.list_files(bucket=bucket)) == 1 + + +def test_large_file_handling(file_storage_client, temp_storage_path): + bucket = "test-bucket" + large_file_path = os.path.join(temp_storage_path, "large_sample.bin") + with open(large_file_path, "wb") as f: + f.write(os.urandom(10 * 1024 * 1024)) # 10 MB file + + uri = file_storage_client.upload_file( + bucket=bucket, + file_path=large_file_path, + storage_type="local", + custom_metadata={"description": "Large file test"}, + ) + file_data, metadata = file_storage_client.storage_system.get_file(uri) + assert file_data.read() == open(large_file_path, "rb").read() + assert metadata.file_name == "large_sample.bin" + assert metadata.bucket == bucket + + +def test_file_hash_verification_success(file_storage_client, sample_file_path): + bucket = "test-bucket" + # Upload file and + uri = file_storage_client.upload_file( + bucket=bucket, file_path=sample_file_path, storage_type="local" + ) + + file_data, metadata = file_storage_client.storage_system.get_file(uri) + file_hash = metadata.file_hash + calculated_hash = file_storage_client.storage_system._calculate_file_hash(file_data) + + assert ( + file_hash == calculated_hash + ), "File hash should match after saving and loading" + + +def test_file_hash_verification_failure(file_storage_client, sample_file_path): + bucket = "test-bucket" + # Upload file and + uri = file_storage_client.upload_file( + bucket=bucket, file_path=sample_file_path, storage_type="local" + ) + + # Modify the file content manually to simulate file tampering + storage_system = file_storage_client.storage_system + metadata = storage_system.metadata_storage.load( + FileMetadataIdentifier(file_id=uri.split("/")[-1], bucket=bucket), FileMetadata + ) + with open(metadata.storage_path, "wb") as f: + f.write(b"Tampered content") + + # Get file should raise an exception due to hash mismatch + with pytest.raises(ValueError, match="File integrity check failed. 
Hash mismatch."): + storage_system.get_file(uri) + + +def test_file_isolation_across_buckets(file_storage_client, sample_file_path): + bucket1 = "bucket1" + bucket2 = "bucket2" + + # Upload the same file to two different buckets + uri1 = file_storage_client.upload_file( + bucket=bucket1, file_path=sample_file_path, storage_type="local" + ) + uri2 = file_storage_client.upload_file( + bucket=bucket2, file_path=sample_file_path, storage_type="local" + ) + + # Verify both URIs are different and point to different files + assert uri1 != uri2 + + file_data1, metadata1 = file_storage_client.storage_system.get_file(uri1) + file_data2, metadata2 = file_storage_client.storage_system.get_file(uri2) + + assert file_data1.read() == b"Sample file content" + assert file_data2.read() == b"Sample file content" + assert metadata1.bucket == bucket1 + assert metadata2.bucket == bucket2 + + +def test_list_files_in_specific_bucket(file_storage_client, sample_file_path): + bucket1 = "bucket1" + bucket2 = "bucket2" + + # Upload a file to both buckets + file_storage_client.upload_file( + bucket=bucket1, file_path=sample_file_path, storage_type="local" + ) + file_storage_client.upload_file( + bucket=bucket2, file_path=sample_file_path, storage_type="local" + ) + + # List files in bucket1 and bucket2 + files_in_bucket1 = file_storage_client.list_files(bucket=bucket1) + files_in_bucket2 = file_storage_client.list_files(bucket=bucket2) + + assert len(files_in_bucket1) == 1 + assert len(files_in_bucket2) == 1 + assert files_in_bucket1[0].bucket == bucket1 + assert files_in_bucket2[0].bucket == bucket2 + + +def test_delete_file_in_one_bucket_does_not_affect_other_bucket( + file_storage_client, sample_file_path +): + bucket1 = "bucket1" + bucket2 = "bucket2" + + # Upload the same file to two different buckets + uri1 = file_storage_client.upload_file( + bucket=bucket1, file_path=sample_file_path, storage_type="local" + ) + uri2 = file_storage_client.upload_file( + bucket=bucket2, file_path=sample_file_path, storage_type="local" + ) + + # Delete the file in bucket1 + file_storage_client.delete_file(uri1) + + # Check that the file in bucket1 is deleted + assert len(file_storage_client.list_files(bucket=bucket1)) == 0 + + # Check that the file in bucket2 is still there + assert len(file_storage_client.list_files(bucket=bucket2)) == 1 + file_data2, metadata2 = file_storage_client.storage_system.get_file(uri2) + assert file_data2.read() == b"Sample file content" + + +def test_file_hash_verification_in_different_buckets( + file_storage_client, sample_file_path +): + bucket1 = "bucket1" + bucket2 = "bucket2" + + # Upload the file to both buckets + uri1 = file_storage_client.upload_file( + bucket=bucket1, file_path=sample_file_path, storage_type="local" + ) + uri2 = file_storage_client.upload_file( + bucket=bucket2, file_path=sample_file_path, storage_type="local" + ) + + file_data1, metadata1 = file_storage_client.storage_system.get_file(uri1) + file_data2, metadata2 = file_storage_client.storage_system.get_file(uri2) + + # Verify that file hashes are the same for the same content + file_hash1 = file_storage_client.storage_system._calculate_file_hash(file_data1) + file_hash2 = file_storage_client.storage_system._calculate_file_hash(file_data2) + + assert file_hash1 == metadata1.file_hash + assert file_hash2 == metadata2.file_hash + assert file_hash1 == file_hash2 + + +def test_file_download_from_different_buckets( + file_storage_client, sample_file_path, temp_storage_path +): + bucket1 = "bucket1" + bucket2 = "bucket2" + + # 
Upload the file to both buckets + uri1 = file_storage_client.upload_file( + bucket=bucket1, file_path=sample_file_path, storage_type="local" + ) + uri2 = file_storage_client.upload_file( + bucket=bucket2, file_path=sample_file_path, storage_type="local" + ) + + # Download files to different locations + download_path1 = os.path.join(temp_storage_path, "downloaded_bucket1.txt") + download_path2 = os.path.join(temp_storage_path, "downloaded_bucket2.txt") + + file_storage_client.download_file(uri1, download_path1) + file_storage_client.download_file(uri2, download_path2) + + # Verify contents of downloaded files + assert open(download_path1, "rb").read() == b"Sample file content" + assert open(download_path2, "rb").read() == b"Sample file content" + + +def test_delete_all_files_in_bucket(file_storage_client, sample_file_path): + bucket1 = "bucket1" + bucket2 = "bucket2" + + # Upload files to both buckets + file_storage_client.upload_file( + bucket=bucket1, file_path=sample_file_path, storage_type="local" + ) + file_storage_client.upload_file( + bucket=bucket2, file_path=sample_file_path, storage_type="local" + ) + + # Delete all files in bucket1 + for file in file_storage_client.list_files(bucket=bucket1): + file_storage_client.delete_file(file.uri) + + # Verify bucket1 is empty + assert len(file_storage_client.list_files(bucket=bucket1)) == 0 + + # Verify bucket2 still has files + assert len(file_storage_client.list_files(bucket=bucket2)) == 1 + + +def test_simple_distributed_storage_save_file( + distributed_storage_backend, sample_file_data, temp_storage_path +): + bucket = "test-bucket" + file_id = "test_file" + file_path = distributed_storage_backend.save(bucket, file_id, sample_file_data) + + expected_path = os.path.join( + temp_storage_path, + bucket, + f"{file_id}_{hashlib.md5('127.0.0.1:8000'.encode()).hexdigest()}", + ) + assert file_path == f"distributed://127.0.0.1:8000/{bucket}/{file_id}" + assert os.path.exists(expected_path) + + +def test_simple_distributed_storage_load_file_local( + distributed_storage_backend, sample_file_data +): + bucket = "test-bucket" + file_id = "test_file" + distributed_storage_backend.save(bucket, file_id, sample_file_data) + + metadata = FileMetadata( + file_id=file_id, + bucket=bucket, + file_name="test.txt", + file_size=len(sample_file_data.getvalue()), + storage_type="distributed", + storage_path=f"distributed://127.0.0.1:8000/{bucket}/{file_id}", + uri=f"distributed://127.0.0.1:8000/{bucket}/{file_id}", + custom_metadata={}, + file_hash="hash", + ) + + file_data = distributed_storage_backend.load(metadata) + assert file_data.read() == b"Sample file content for distributed storage" + + +@mock.patch("requests.get") +def test_simple_distributed_storage_load_file_remote( + mock_get, distributed_storage_backend, sample_file_data +): + bucket = "test-bucket" + file_id = "test_file" + remote_node_address = "127.0.0.2:8000" + + # Mock the response from remote node + mock_response = mock.Mock() + mock_response.iter_content = mock.Mock( + return_value=iter([b"Sample file content for distributed storage"]) + ) + mock_response.raise_for_status = mock.Mock(return_value=None) + mock_get.return_value = mock_response + + metadata = FileMetadata( + file_id=file_id, + bucket=bucket, + file_name="test.txt", + file_size=len(sample_file_data.getvalue()), + storage_type="distributed", + storage_path=f"distributed://{remote_node_address}/{bucket}/{file_id}", + uri=f"distributed://{remote_node_address}/{bucket}/{file_id}", + custom_metadata={}, + file_hash="hash", + ) + + 
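# The metadata points at a different node address than this backend's
+    # own, so load() fetches the file over HTTP (mocked above) and wraps the
+    # streamed response in StreamedBytesIO.
+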
file_data = distributed_storage_backend.load(metadata) + assert file_data.read() == b"Sample file content for distributed storage" + mock_get.assert_called_once_with( + f"http://{remote_node_address}/api/v2/serve/file/files/{bucket}/{file_id}", + stream=True, + timeout=360, + ) + + +def test_simple_distributed_storage_delete_file_local( + distributed_storage_backend, sample_file_data, temp_storage_path +): + bucket = "test-bucket" + file_id = "test_file" + distributed_storage_backend.save(bucket, file_id, sample_file_data) + + metadata = FileMetadata( + file_id=file_id, + bucket=bucket, + file_name="test.txt", + file_size=len(sample_file_data.getvalue()), + storage_type="distributed", + storage_path=f"distributed://127.0.0.1:8000/{bucket}/{file_id}", + uri=f"distributed://127.0.0.1:8000/{bucket}/{file_id}", + custom_metadata={}, + file_hash="hash", + ) + + result = distributed_storage_backend.delete(metadata) + file_path = os.path.join( + temp_storage_path, + bucket, + f"{file_id}_{hashlib.md5('127.0.0.1:8000'.encode()).hexdigest()}", + ) + assert result is True + assert not os.path.exists(file_path) + + +@mock.patch("requests.delete") +def test_simple_distributed_storage_delete_file_remote( + mock_delete, distributed_storage_backend, sample_file_data +): + bucket = "test-bucket" + file_id = "test_file" + remote_node_address = "127.0.0.2:8000" + + mock_response = mock.Mock() + mock_response.raise_for_status = mock.Mock(return_value=None) + mock_delete.return_value = mock_response + + metadata = FileMetadata( + file_id=file_id, + bucket=bucket, + file_name="test.txt", + file_size=len(sample_file_data.getvalue()), + storage_type="distributed", + storage_path=f"distributed://{remote_node_address}/{bucket}/{file_id}", + uri=f"distributed://{remote_node_address}/{bucket}/{file_id}", + custom_metadata={}, + file_hash="hash", + ) + + result = distributed_storage_backend.delete(metadata) + assert result is True + mock_delete.assert_called_once_with( + f"http://{remote_node_address}/api/v2/serve/file/files/{bucket}/{file_id}", + timeout=360, + ) diff --git a/dbgpt/core/interface/tests/test_variables.py b/dbgpt/core/interface/tests/test_variables.py new file mode 100644 index 000000000..313657b4e --- /dev/null +++ b/dbgpt/core/interface/tests/test_variables.py @@ -0,0 +1,327 @@ +import base64 +import os +from itertools import product + +from cryptography.fernet import Fernet + +from ..variables import ( + FernetEncryption, + InMemoryStorage, + SimpleEncryption, + StorageVariables, + StorageVariablesProvider, + VariablesIdentifier, + build_variable_string, + parse_variable, +) + + +def test_fernet_encryption(): + key = Fernet.generate_key() + encryption = FernetEncryption(key) + new_encryption = FernetEncryption(key) + data = "test_data" + salt = "test_salt" + + encrypted_data = encryption.encrypt(data, salt) + assert encrypted_data != data + + decrypted_data = encryption.decrypt(encrypted_data, salt) + assert decrypted_data == data + assert decrypted_data == new_encryption.decrypt(encrypted_data, salt) + + +def test_simple_encryption(): + key = base64.b64encode(os.urandom(32)).decode() + encryption = SimpleEncryption(key) + data = "test_data" + salt = "test_salt" + + encrypted_data = encryption.encrypt(data, salt) + assert encrypted_data != data + + decrypted_data = encryption.decrypt(encrypted_data, salt) + assert decrypted_data == data + + +def test_storage_variables_provider(): + storage = InMemoryStorage() + encryption = SimpleEncryption() + provider = StorageVariablesProvider(storage, 
encryption) + + full_key = "${key:name@global}" + value = "secret_value" + value_type = "str" + label = "test_label" + + id = VariablesIdentifier.from_str_identifier(full_key) + provider.save( + StorageVariables.from_identifier( + id, value, value_type, label, category="secret" + ) + ) + + loaded_variable_value = provider.get(full_key) + assert loaded_variable_value == value + + +def test_variables_identifier(): + full_key = "${key:name@global:scope_key#sys_code%user_name}" + identifier = VariablesIdentifier.from_str_identifier(full_key) + + assert identifier.key == "key" + assert identifier.name == "name" + assert identifier.scope == "global" + assert identifier.scope_key == "scope_key" + assert identifier.sys_code == "sys_code" + assert identifier.user_name == "user_name" + + str_identifier = identifier.str_identifier + assert str_identifier == full_key + + +def test_storage_variables(): + key = "test_key" + name = "test_name" + label = "test_label" + value = "test_value" + value_type = "str" + category = "common" + scope = "global" + + storage_variable = StorageVariables( + key=key, + name=name, + label=label, + value=value, + value_type=value_type, + category=category, + scope=scope, + ) + + assert storage_variable.key == key + assert storage_variable.name == name + assert storage_variable.label == label + assert storage_variable.value == value + assert storage_variable.value_type == value_type + assert storage_variable.category == category + assert storage_variable.scope == scope + + dict_representation = storage_variable.to_dict() + assert dict_representation["key"] == key + assert dict_representation["name"] == name + assert dict_representation["label"] == label + assert dict_representation["value"] == value + assert dict_representation["value_type"] == value_type + assert dict_representation["category"] == category + assert dict_representation["scope"] == scope + + +def generate_test_cases(enable_escape=False): + # Define possible values for each field, including special characters for escaping + _EMPTY_ = "___EMPTY___" + fields = { + "name": [ + None, + "test_name", + "test:name" if enable_escape else _EMPTY_, + "test::name" if enable_escape else _EMPTY_, + "test#name" if enable_escape else _EMPTY_, + "test##name" if enable_escape else _EMPTY_, + "test::@@@#22name" if enable_escape else _EMPTY_, + ], + "scope": [ + None, + "test_scope", + "test@scope" if enable_escape else _EMPTY_, + "test@@scope" if enable_escape else _EMPTY_, + "test:scope" if enable_escape else _EMPTY_, + "test:#:scope" if enable_escape else _EMPTY_, + ], + "scope_key": [ + None, + "test_scope_key", + "test:scope_key" if enable_escape else _EMPTY_, + ], + "sys_code": [ + None, + "test_sys_code", + "test#sys_code" if enable_escape else _EMPTY_, + ], + "user_name": [ + None, + "test_user_name", + "test%user_name" if enable_escape else _EMPTY_, + ], + } + # Remove empty values + fields = {k: [v for v in values if v != _EMPTY_] for k, values in fields.items()} + + # Generate all possible combinations + combinations = product(*fields.values()) + + test_cases = [] + for combo in combinations: + name, scope, scope_key, sys_code, user_name = combo + + var_str = build_variable_string( + { + "key": "test_key", + "name": name, + "scope": scope, + "scope_key": scope_key, + "sys_code": sys_code, + "user_name": user_name, + }, + enable_escape=enable_escape, + ) + + # Construct the expected output + expected = { + "key": "test_key", + "name": name, + "scope": scope, + "scope_key": scope_key, + "sys_code": sys_code, + 
"user_name": user_name, + } + + test_cases.append((var_str, expected, enable_escape)) + + return test_cases + + +def test_parse_variables(): + # Run test cases without escape + test_cases = generate_test_cases(enable_escape=False) + for i, (input_str, expected_output, enable_escape) in enumerate(test_cases, 1): + result = parse_variable(input_str, enable_escape=enable_escape) + assert result == expected_output, f"Test case {i} failed without escape" + + # Run test cases with escape + test_cases = generate_test_cases(enable_escape=True) + for i, (input_str, expected_output, enable_escape) in enumerate(test_cases, 1): + print(f"input_str: {input_str}, expected_output: {expected_output}") + result = parse_variable(input_str, enable_escape=enable_escape) + assert result == expected_output, f"Test case {i} failed with escape" + + +def generate_build_test_cases(enable_escape=False): + # Define possible values for each field, including special characters for escaping + _EMPTY_ = "___EMPTY___" + fields = { + "name": [ + None, + "test_name", + "test:name" if enable_escape else _EMPTY_, + "test::name" if enable_escape else _EMPTY_, + "test\name" if enable_escape else _EMPTY_, + "test\\name" if enable_escape else _EMPTY_, + "test\:\#\@\%name" if enable_escape else _EMPTY_, + "test\::\###\@@\%%name" if enable_escape else _EMPTY_, + "test\\::\\###\\@@\\%%name" if enable_escape else _EMPTY_, + "test\:#:name" if enable_escape else _EMPTY_, + ], + "scope": [None, "test_scope", "test@scope" if enable_escape else _EMPTY_], + "scope_key": [ + None, + "test_scope_key", + "test:scope_key" if enable_escape else _EMPTY_, + ], + "sys_code": [ + None, + "test_sys_code", + "test#sys_code" if enable_escape else _EMPTY_, + ], + "user_name": [ + None, + "test_user_name", + "test%user_name" if enable_escape else _EMPTY_, + ], + } + # Remove empty values + fields = {k: [v for v in values if v != _EMPTY_] for k, values in fields.items()} + + # Generate all possible combinations + combinations = product(*fields.values()) + + test_cases = [] + + def escape_special_chars(s): + if not enable_escape or s is None: + return s + return ( + s.replace(":", "\\:") + .replace("@", "\\@") + .replace("%", "\\%") + .replace("#", "\\#") + ) + + for combo in combinations: + name, scope, scope_key, sys_code, user_name = combo + + # Construct the input dictionary + input_dict = { + "key": "test_key", + "name": name, + "scope": scope, + "scope_key": scope_key, + "sys_code": sys_code, + "user_name": user_name, + } + input_dict_with_escape = { + k: escape_special_chars(v) for k, v in input_dict.items() + } + + # Construct the expected variable string + expected_str = "${test_key" + if name: + expected_str += f":{input_dict_with_escape['name']}" + if scope or scope_key: + expected_str += "@" + if scope: + expected_str += input_dict_with_escape["scope"] + if scope_key: + expected_str += f":{input_dict_with_escape['scope_key']}" + if sys_code: + expected_str += f"#{input_dict_with_escape['sys_code']}" + if user_name: + expected_str += f"%{input_dict_with_escape['user_name']}" + expected_str += "}" + + test_cases.append((input_dict, expected_str, enable_escape)) + + return test_cases + + +def test_build_variable_string(): + # Run test cases without escape + test_cases = generate_build_test_cases(enable_escape=False) + for i, (input_dict, expected_str, enable_escape) in enumerate(test_cases, 1): + result = build_variable_string(input_dict, enable_escape=enable_escape) + assert result == expected_str, f"Test case {i} failed without escape" + + # 
Run test cases with escape + test_cases = generate_build_test_cases(enable_escape=True) + for i, (input_dict, expected_str, enable_escape) in enumerate(test_cases, 1): + print(f"input_dict: {input_dict}, expected_str: {expected_str}") + result = build_variable_string(input_dict, enable_escape=enable_escape) + assert result == expected_str, f"Test case {i} failed with escape" + + +def test_variable_string_round_trip(): + # Run test cases without escape + test_cases = generate_test_cases(enable_escape=False) + for i, (input_str, expected_output, enable_escape) in enumerate(test_cases, 1): + parsed_result = parse_variable(input_str, enable_escape=enable_escape) + built_result = build_variable_string(parsed_result, enable_escape=enable_escape) + assert ( + built_result == input_str + ), f"Round trip test case {i} failed without escape" + + # Run test cases with escape + test_cases = generate_test_cases(enable_escape=True) + for i, (input_str, expected_output, enable_escape) in enumerate(test_cases, 1): + parsed_result = parse_variable(input_str, enable_escape=enable_escape) + built_result = build_variable_string(parsed_result, enable_escape=enable_escape) + assert built_result == input_str, f"Round trip test case {i} failed with escape" diff --git a/dbgpt/core/interface/variables.py b/dbgpt/core/interface/variables.py new file mode 100644 index 000000000..22e035d52 --- /dev/null +++ b/dbgpt/core/interface/variables.py @@ -0,0 +1,979 @@ +"""Variables Module.""" + +import base64 +import dataclasses +import hashlib +import json +import os +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Literal, Optional, Tuple, Union + +from dbgpt.component import BaseComponent, ComponentType, SystemApp +from dbgpt.util.executor_utils import ( + DefaultExecutorFactory, + blocking_func_to_async, + blocking_func_to_async_no_executor, +) + +from .storage import ( + InMemoryStorage, + QuerySpec, + ResourceIdentifier, + StorageInterface, + StorageItem, +) + +_EMPTY_DEFAULT_VALUE = "_EMPTY_DEFAULT_VALUE" + +BUILTIN_VARIABLES_CORE_FLOWS = "dbgpt.core.flow.flows" +BUILTIN_VARIABLES_CORE_FLOW_NODES = "dbgpt.core.flow.nodes" +BUILTIN_VARIABLES_CORE_VARIABLES = "dbgpt.core.variables" +BUILTIN_VARIABLES_CORE_SECRETS = "dbgpt.core.secrets" +BUILTIN_VARIABLES_CORE_LLMS = "dbgpt.core.model.llms" +BUILTIN_VARIABLES_CORE_EMBEDDINGS = "dbgpt.core.model.embeddings" +BUILTIN_VARIABLES_CORE_RERANKERS = "dbgpt.core.model.rerankers" +BUILTIN_VARIABLES_CORE_DATASOURCES = "dbgpt.core.datasources" +BUILTIN_VARIABLES_CORE_AGENTS = "dbgpt.core.agent.agents" +BUILTIN_VARIABLES_CORE_KNOWLEDGE_SPACES = "dbgpt.core.knowledge_spaces" + + +class Encryption(ABC): + """Encryption interface.""" + + name: str = "__abstract__" + + @abstractmethod + def encrypt(self, data: str, salt: str) -> str: + """Encrypt the data.""" + + @abstractmethod + def decrypt(self, encrypted_data: str, salt: str) -> str: + """Decrypt the data.""" + + +def _generate_key_from_password( + password: bytes, salt: Optional[Union[str, bytes]] = None +): + from cryptography.hazmat.primitives import hashes + from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC + + if salt is None: + salt = os.urandom(16) + elif isinstance(salt, str): + salt = salt.encode() + kdf = PBKDF2HMAC( + algorithm=hashes.SHA256(), + length=32, + salt=salt, + iterations=100000, + ) + key = base64.urlsafe_b64encode(kdf.derive(password)) + return key, salt + + +class FernetEncryption(Encryption): + """Fernet encryption. 
+ + A symmetric encryption algorithm that uses the same key for both encryption and + decryption which is powered by the cryptography library. + """ + + name = "fernet" + + def __init__(self, key: Optional[bytes] = None): + """Initialize the fernet encryption.""" + if key is not None and isinstance(key, str): + key = key.encode() + try: + from cryptography.fernet import Fernet + except ImportError: + raise ImportError( + "cryptography is required for encryption, please install by running " + "`pip install cryptography`" + ) + if key is None: + key = Fernet.generate_key() + self.key = key + + def encrypt(self, data: str, salt: str) -> str: + """Encrypt the data with the salt. + + Args: + data (str): The data to encrypt. + salt (str): The salt to use, which is used to derive the key. + + Returns: + str: The encrypted data. + """ + from cryptography.fernet import Fernet + + key, salt = _generate_key_from_password(self.key, salt) + fernet = Fernet(key) + encrypted_secret = fernet.encrypt(data.encode()).decode() + return encrypted_secret + + def decrypt(self, encrypted_data: str, salt: str) -> str: + """Decrypt the data with the salt. + + Args: + encrypted_data (str): The encrypted data. + salt (str): The salt to use, which is used to derive the key. + + Returns: + str: The decrypted data. + """ + from cryptography.fernet import Fernet + + key, salt = _generate_key_from_password(self.key, salt) + fernet = Fernet(key) + return fernet.decrypt(encrypted_data.encode()).decode() + + +class SimpleEncryption(Encryption): + """Simple implementation of encryption. + + A simple encryption algorithm that uses a key to XOR the data. + """ + + name = "simple" + + def __init__(self, key: Optional[str] = None): + """Initialize the simple encryption.""" + if key is None: + key = base64.b64encode(os.urandom(32)).decode() + self.key = key + + def _derive_key(self, salt: str) -> bytes: + return hashlib.pbkdf2_hmac("sha256", self.key.encode(), salt.encode(), 100000) + + def encrypt(self, data: str, salt: str) -> str: + """Encrypt the data with the salt.""" + key = self._derive_key(salt) + encrypted = bytes( + x ^ y for x, y in zip(data.encode(), key * (len(data) // len(key) + 1)) + ) + return base64.b64encode(encrypted).decode() + + def decrypt(self, encrypted_data: str, salt: str) -> str: + """Decrypt the data with the salt.""" + key = self._derive_key(salt) + data = base64.b64decode(encrypted_data) + decrypted = bytes( + x ^ y for x, y in zip(data, key * (len(data) // len(key) + 1)) + ) + return decrypted.decode() + + +@dataclasses.dataclass +class VariablesIdentifier(ResourceIdentifier): + """The variables identifier.""" + + identifier_split: str = dataclasses.field(default="@", init=False) + + key: str + name: str + scope: str = "global" + scope_key: Optional[str] = None + sys_code: Optional[str] = None + user_name: Optional[str] = None + + def __post_init__(self): + """Post init method.""" + if not self.key or not self.name or not self.scope: + raise ValueError("Key, name, and scope are required.") + + @property + def str_identifier(self) -> str: + """Return the string identifier of the identifier.""" + return build_variable_string( + { + "key": self.key, + "name": self.name, + "scope": self.scope, + "scope_key": self.scope_key, + "sys_code": self.sys_code, + "user_name": self.user_name, + } + ) + + def to_dict(self) -> Dict: + """Convert the identifier to a dict. + + Returns: + Dict: The dict of the identifier. 
+ """ + return { + "key": self.key, + "name": self.name, + "scope": self.scope, + "scope_key": self.scope_key, + "sys_code": self.sys_code, + "user_name": self.user_name, + } + + @classmethod + def from_str_identifier( + cls, + str_identifier: str, + default_identifier_map: Optional[Dict[str, str]] = None, + ) -> "VariablesIdentifier": + """Create a VariablesIdentifier from a string identifier. + + Args: + str_identifier (str): The string identifier. + default_identifier_map (Optional[Dict[str, str]]): The default identifier + map, which contains the default values for the identifier. Defaults to + None. + + Returns: + VariablesIdentifier: The VariablesIdentifier. + """ + variable_dict = parse_variable(str_identifier) + if not variable_dict: + raise ValueError("Invalid string identifier.") + if not variable_dict.get("key"): + raise ValueError("Invalid string identifier, must have key") + if not variable_dict.get("name"): + raise ValueError("Invalid string identifier, must have name") + + def _get_value(key, default_value: Optional[str] = None) -> Optional[str]: + if variable_dict.get(key) is not None: + return variable_dict.get(key) + if default_identifier_map is not None and default_identifier_map.get(key): + return default_identifier_map.get(key) + return default_value + + return cls( + key=variable_dict["key"], + name=variable_dict["name"], + scope=variable_dict["scope"], + scope_key=_get_value("scope_key"), + sys_code=_get_value("sys_code"), + user_name=_get_value("user_name"), + ) + + +@dataclasses.dataclass +class StorageVariables(StorageItem): + """The storage variables.""" + + key: str + name: str + label: str + value: Any + category: Literal["common", "secret"] = "common" + scope: str = "global" + value_type: Optional[str] = None + scope_key: Optional[str] = None + sys_code: Optional[str] = None + user_name: Optional[str] = None + encryption_method: Optional[str] = None + salt: Optional[str] = None + enabled: int = 1 + description: Optional[str] = None + + _identifier: VariablesIdentifier = dataclasses.field(init=False) + + def __post_init__(self): + """Post init method.""" + self._identifier = VariablesIdentifier( + key=self.key, + name=self.name, + scope=self.scope, + scope_key=self.scope_key, + sys_code=self.sys_code, + user_name=self.user_name, + ) + if not self.value_type: + self.value_type = type(self.value).__name__ + + @property + def identifier(self) -> ResourceIdentifier: + """Return the identifier.""" + return self._identifier + + def merge(self, other: "StorageItem") -> None: + """Merge with another storage variables.""" + if not isinstance(other, StorageVariables): + raise ValueError(f"Cannot merge with {type(other)}") + self.from_object(other) + + def to_dict(self) -> Dict: + """Convert the storage variables to a dict. + + Returns: + Dict: The dict of the storage variables. 
+ """ + return { + **self._identifier.to_dict(), + "label": self.label, + "value": self.value, + "value_type": self.value_type, + "category": self.category, + "encryption_method": self.encryption_method, + "salt": self.salt, + "enabled": self.enabled, + "description": self.description, + } + + def from_object(self, other: "StorageVariables") -> None: + """Copy the values from another storage variables object.""" + self.label = other.label + self.value = other.value + self.value_type = other.value_type + self.category = other.category + self.scope = other.scope + self.scope_key = other.scope_key + self.sys_code = other.sys_code + self.user_name = other.user_name + self.encryption_method = other.encryption_method + self.salt = other.salt + self.enabled = other.enabled + self.description = other.description + + @classmethod + def from_identifier( + cls, + identifier: VariablesIdentifier, + value: Any, + value_type: str, + label: str = "", + category: Literal["common", "secret"] = "common", + encryption_method: Optional[str] = None, + salt: Optional[str] = None, + ) -> "StorageVariables": + """Copy the values from an identifier.""" + return cls( + key=identifier.key, + name=identifier.name, + label=label, + value=value, + value_type=value_type, + category=category, + scope=identifier.scope, + scope_key=identifier.scope_key, + sys_code=identifier.sys_code, + user_name=identifier.user_name, + encryption_method=encryption_method, + salt=salt, + ) + + +class VariablesProvider(BaseComponent, ABC): + """The variables provider interface.""" + + name = ComponentType.VARIABLES_PROVIDER.value + + @abstractmethod + def get( + self, + full_key: str, + default_value: Optional[str] = _EMPTY_DEFAULT_VALUE, + default_identifier_map: Optional[Dict[str, str]] = None, + ) -> Any: + """Query variables from storage.""" + + @abstractmethod + def save(self, variables_item: StorageVariables) -> None: + """Save variables to storage.""" + + @abstractmethod + def get_variables( + self, + key: str, + scope: str = "global", + scope_key: Optional[str] = None, + sys_code: Optional[str] = None, + user_name: Optional[str] = None, + ) -> List[StorageVariables]: + """Get variables by key.""" + + async def async_get_variables( + self, + key: str, + scope: str = "global", + scope_key: Optional[str] = None, + sys_code: Optional[str] = None, + user_name: Optional[str] = None, + ) -> List[StorageVariables]: + """Get variables by key async.""" + raise NotImplementedError("Current variables provider does not support async.") + + def support_async(self) -> bool: + """Whether the variables provider support async.""" + return False + + def _convert_to_value_type(self, var: StorageVariables): + """Convert the variable to the value type.""" + if var.value is None: + return None + if var.value_type == "str": + return str(var.value) + elif var.value_type == "int": + return int(var.value) + elif var.value_type == "float": + return float(var.value) + elif var.value_type == "bool": + if var.value.lower() in ["true", "1"]: + return True + elif var.value.lower() in ["false", "0"]: + return False + else: + return bool(var.value) + else: + return var.value + + +class VariablesPlaceHolder: + """The variables place holder.""" + + def __init__( + self, + param_name: str, + full_key: str, + default_value: Any = _EMPTY_DEFAULT_VALUE, + ): + """Initialize the variables place holder.""" + self.param_name = param_name + self.full_key = full_key + self.default_value = default_value + + def parse( + self, + variables_provider: VariablesProvider, + 
        ignore_not_found_error: bool = False,
+        default_identifier_map: Optional[Dict[str, str]] = None,
+    ):
+        """Parse the variables."""
+        try:
+            return variables_provider.get(
+                self.full_key,
+                self.default_value,
+                default_identifier_map=default_identifier_map,
+            )
+        except ValueError as e:
+            if ignore_not_found_error:
+                return None
+            raise e
+
+    def __repr__(self):
+        """Return the representation of the variables place holder."""
+        return f"<VariablesPlaceHolder: {self.param_name} {self.full_key}>"
+
+
+class StorageVariablesProvider(VariablesProvider):
+    """The storage variables provider."""
+
+    def __init__(
+        self,
+        storage: Optional[StorageInterface] = None,
+        encryption: Optional[Encryption] = None,
+        system_app: Optional[SystemApp] = None,
+        key: Optional[str] = None,
+    ):
+        """Initialize the storage variables provider."""
+        if storage is None:
+            storage = InMemoryStorage()
+        self.system_app = system_app
+        self.encryption = encryption or SimpleEncryption(key)
+
+        self.storage = storage
+        super().__init__(system_app)
+
+    def init_app(self, system_app: SystemApp):
+        """Initialize the storage variables provider."""
+        self.system_app = system_app
+
+    def get(
+        self,
+        full_key: str,
+        default_value: Optional[str] = _EMPTY_DEFAULT_VALUE,
+        default_identifier_map: Optional[Dict[str, str]] = None,
+    ) -> Any:
+        """Query variables from storage."""
+        key = VariablesIdentifier.from_str_identifier(full_key, default_identifier_map)
+        variable: Optional[StorageVariables] = self.storage.load(key, StorageVariables)
+        if variable is None:
+            if default_value == _EMPTY_DEFAULT_VALUE:
+                raise ValueError(f"Variable {full_key} not found")
+            return default_value
+        variable.value = self.deserialize_value(variable.value)
+        if (
+            variable.value is not None
+            and variable.category == "secret"
+            and variable.encryption_method
+            and variable.salt
+        ):
+            variable.value = self.encryption.decrypt(variable.value, variable.salt)
+        return self._convert_to_value_type(variable)
+
+    def save(self, variables_item: StorageVariables) -> None:
+        """Save variables to storage."""
+        if variables_item.category == "secret":
+            salt = base64.b64encode(os.urandom(16)).decode()
+            variables_item.value = self.encryption.encrypt(
+                str(variables_item.value), salt
+            )
+            variables_item.encryption_method = self.encryption.name
+            variables_item.salt = salt
+        # Replace the value with a JSON serializable object
+        variables_item.value = self.serialize_value(variables_item.value)
+
+        self.storage.save_or_update(variables_item)
+
+    def get_variables(
+        self,
+        key: str,
+        scope: str = "global",
+        scope_key: Optional[str] = None,
+        sys_code: Optional[str] = None,
+        user_name: Optional[str] = None,
+    ) -> List[StorageVariables]:
+        """Query variables from storage."""
+        # Try to get builtin variables
+        is_builtin, builtin_variables = self._get_builtins_variables(
+            key,
+            scope=scope,
+            scope_key=scope_key,
+            sys_code=sys_code,
+            user_name=user_name,
+        )
+        if is_builtin:
+            return builtin_variables
+        variables = self.storage.query(
+            QuerySpec(
+                conditions={
+                    "key": key,
+                    "scope": scope,
+                    "scope_key": scope_key,
+                    "sys_code": sys_code,
+                    "user_name": user_name,
+                    "enabled": 1,
+                }
+            ),
+            StorageVariables,
+        )
+        for variable in variables:
+            variable.value = self.deserialize_value(variable.value)
+        return variables
+
+    async def async_get_variables(
+        self,
+        key: str,
+        scope: str = "global",
+        scope_key: Optional[str] = None,
+        sys_code: Optional[str] = None,
+        user_name: Optional[str] = None,
+    ) -> List[StorageVariables]:
+        """Query variables from storage async."""
+        # Try to get builtin variables
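+        # Builtin providers registered for this key take precedence over
+        # persisted variables; the blocking storage query below is offloaded
+        # to an executor so the event loop is not blocked.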
+        is_builtin, builtin_variables = await self._async_get_builtins_variables(
+            key,
+            scope=scope,
+            scope_key=scope_key,
+            sys_code=sys_code,
+            user_name=user_name,
+        )
+        if is_builtin:
+            return builtin_variables
+        executor_factory: Optional[
+            DefaultExecutorFactory
+        ] = DefaultExecutorFactory.get_instance(self.system_app, default_component=None)
+        if executor_factory:
+            return await blocking_func_to_async(
+                executor_factory.create(),
+                self.get_variables,
+                key,
+                scope,
+                scope_key,
+                sys_code,
+                user_name,
+            )
+        else:
+            return await blocking_func_to_async_no_executor(
+                self.get_variables, key, scope, scope_key, sys_code, user_name
+            )
+
+    def _get_builtins_variables(
+        self,
+        key: str,
+        scope: str = "global",
+        scope_key: Optional[str] = None,
+        sys_code: Optional[str] = None,
+        user_name: Optional[str] = None,
+    ) -> Tuple[bool, List[StorageVariables]]:
+        """Get builtin variables."""
+        if self.system_app is None:
+            return False, []
+        provider: BuiltinVariablesProvider = self.system_app.get_component(
+            key,
+            component_type=BuiltinVariablesProvider,
+            default_component=None,
+        )
+        if not provider:
+            return False, []
+        return True, provider.get_variables(
+            key,
+            scope=scope,
+            scope_key=scope_key,
+            sys_code=sys_code,
+            user_name=user_name,
+        )
+
+    async def _async_get_builtins_variables(
+        self,
+        key: str,
+        scope: str = "global",
+        scope_key: Optional[str] = None,
+        sys_code: Optional[str] = None,
+        user_name: Optional[str] = None,
+    ) -> Tuple[bool, List[StorageVariables]]:
+        """Get builtin variables async."""
+        if self.system_app is None:
+            return False, []
+        provider: BuiltinVariablesProvider = self.system_app.get_component(
+            key,
+            component_type=BuiltinVariablesProvider,
+            default_component=None,
+        )
+        if not provider:
+            return False, []
+        if not provider.support_async():
+            return False, []
+        return True, await provider.async_get_variables(
+            key,
+            scope=scope,
+            scope_key=scope_key,
+            sys_code=sys_code,
+            user_name=user_name,
+        )
+
+    @classmethod
+    def serialize_value(cls, value: Any) -> str:
+        """Serialize the value."""
+        value_dict = {"value": value}
+        return json.dumps(value_dict, ensure_ascii=False)
+
+    @classmethod
+    def deserialize_value(cls, value: str) -> Any:
+        """Deserialize the value."""
+        value_dict = json.loads(value)
+        return value_dict["value"]
+
+
+class BuiltinVariablesProvider(VariablesProvider, ABC):
+    """The builtin variables provider.
+
+    You can implement this class to provide builtin variables, such as LLMs,
+    agents, datasources, knowledge bases, etc.
+    """
+
+    name = "dbgpt_variables_builtin"
+
+    def __init__(self, system_app: Optional[SystemApp] = None):
+        """Initialize the builtin variables provider."""
+        self.system_app = system_app
+        super().__init__(system_app)
+
+    def init_app(self, system_app: SystemApp):
+        """Initialize the builtin variables provider."""
+        self.system_app = system_app
+
+    def get(
+        self,
+        full_key: str,
+        default_value: Optional[str] = _EMPTY_DEFAULT_VALUE,
+        default_identifier_map: Optional[Dict[str, str]] = None,
+    ) -> Any:
+        """Query variables from storage."""
+        raise NotImplementedError("BuiltinVariablesProvider does not support get.")
+
+    def save(self, variables_item: StorageVariables) -> None:
+        """Save variables to storage."""
+        raise NotImplementedError("BuiltinVariablesProvider does not support save.")
+
+
+def parse_variable(
+    variable_str: str,
+    enable_escape: bool = True,
+) -> Dict[str, Any]:
+    """Parse the variable string.
+
+    Examples:
+        ..
code-block:: python + + cases = [ + { + "full_key": "${test_key:test_name@test_scope:test_scope_key}", + "expected": { + "key": "test_key", + "name": "test_name", + "scope": "test_scope", + "scope_key": "test_scope_key", + "sys_code": None, + "user_name": None, + }, + }, + { + "full_key": "${test_key#test_sys_code}", + "expected": { + "key": "test_key", + "name": None, + "scope": None, + "scope_key": None, + "sys_code": "test_sys_code", + "user_name": None, + }, + }, + { + "full_key": "${test_key@:test_scope_key}", + "expected": { + "key": "test_key", + "name": None, + "scope": None, + "scope_key": "test_scope_key", + "sys_code": None, + "user_name": None, + }, + }, + ] + for case in cases: + assert parse_variable(case["full_key"]) == case["expected"] + Args: + variable_str (str): The variable string. + enable_escape (bool): Whether to handle escaped characters. + Returns: + Dict[str, Any]: The parsed variable. + """ + if not variable_str.startswith("${") or not variable_str.endswith("}"): + raise ValueError( + "Invalid variable format, must start with '${' and end with '}'" + ) + + # Remove the surrounding ${ and } + content = variable_str[2:-1] + + # Define placeholders for escaped characters + placeholders = { + r"\@": "__ESCAPED_AT__", + r"\#": "__ESCAPED_HASH__", + r"\%": "__ESCAPED_PERCENT__", + r"\:": "__ESCAPED_COLON__", + } + + if enable_escape: + # Replace escaped characters with placeholders + for original, placeholder in placeholders.items(): + content = content.replace(original, placeholder) + + # Initialize the result dictionary + result: Dict[str, Optional[str]] = { + "key": None, + "name": None, + "scope": None, + "scope_key": None, + "sys_code": None, + "user_name": None, + } + + # Split the content by special characters + parts = content.split("@") + + # Parse key and name + key_name = parts[0].split("#")[0].split("%")[0] + if ":" in key_name: + result["key"], result["name"] = key_name.split(":", 1) + else: + result["key"] = key_name + + # Parse scope and scope_key + if len(parts) > 1: + scope_part = parts[1].split("#")[0].split("%")[0] + if ":" in scope_part: + result["scope"], result["scope_key"] = scope_part.split(":", 1) + else: + result["scope"] = scope_part + + # Parse sys_code + if "#" in content: + result["sys_code"] = content.split("#", 1)[1].split("%")[0] + + # Parse user_name + if "%" in content: + result["user_name"] = content.split("%", 1)[1] + + if enable_escape: + # Replace placeholders back with escaped characters + reverse_placeholders = {v: k[1:] for k, v in placeholders.items()} + for key, value in result.items(): + if value: + for placeholder, original in reverse_placeholders.items(): + result[key] = result[key].replace( # type: ignore + placeholder, original + ) + + # Replace empty strings with None + for key, value in result.items(): + if value == "": + result[key] = None + + return result + + +def _is_variable_format(value: str) -> bool: + if not value.startswith("${") or not value.endswith("}"): + return False + return True + + +def is_variable_string(variable_str: str) -> bool: + """Check if the given string is a variable string. + + A valid variable string should start with "${" and end with "}", and contain key + and name + + Args: + variable_str (str): The string to check. + + Returns: + bool: True if the string is a variable string, False otherwise. 
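+
+    Examples (illustrative):
+        >>> is_variable_string("${my_key:my_name}")
+        True
+        >>> is_variable_string("${my_key}")
+        False
+        >>> is_variable_string("plain text")
+        False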
+ """ + if not variable_str or not isinstance(variable_str, str): + return False + if not _is_variable_format(variable_str): + return False + try: + variable_dict = parse_variable(variable_str) + if not variable_dict.get("key"): + return False + if not variable_dict.get("name"): + return False + return True + except Exception: + return False + + +def is_variable_list_string(variable_str: str) -> bool: + """Check if the given string is a variable string. + + A valid variable list string should start with "${" and end with "}", and contain + key and not contain name + + A valid variable list string means that the variable is a list of variables with the + same key. + + Args: + variable_str (str): The string to check. + + Returns: + bool: True if the string is a variable string, False otherwise. + """ + if not _is_variable_format(variable_str): + return False + try: + variable_dict = parse_variable(variable_str) + if not variable_dict.get("key"): + return False + if variable_dict.get("name"): + return False + return True + except Exception: + return False + + +def build_variable_string( + variable_dict: Dict[str, Any], + scope_sig: str = "@", + sys_code_sig: str = "#", + user_sig: str = "%", + kv_sig: str = ":", + enable_escape: bool = True, +) -> str: + """Build a variable string from the given dictionary. + + Args: + variable_dict (Dict[str, Any]): The dictionary containing the variable details. + scope_sig (str): The scope signature. + sys_code_sig (str): The sys code signature. + user_sig (str): The user signature. + kv_sig (str): The key-value split signature. + enable_escape (bool): Whether to escape special characters + + Returns: + str: The formatted variable string. + + Examples: + >>> build_variable_string( + ... { + ... "key": "test_key", + ... "name": "test_name", + ... "scope": "test_scope", + ... "scope_key": "test_scope_key", + ... "sys_code": "test_sys_code", + ... "user_name": "test_user", + ... } + ... 
) + '${test_key:test_name@test_scope:test_scope_key#test_sys_code%test_user}' + + >>> build_variable_string({"key": "test_key", "scope_key": "test_scope_key"}) + '${test_key@:test_scope_key}' + + >>> build_variable_string({"key": "test_key", "sys_code": "test_sys_code"}) + '${test_key#test_sys_code}' + + >>> build_variable_string({"key": "test_key"}) + '${test_key}' + """ + special_chars = {scope_sig, sys_code_sig, user_sig, kv_sig} + # Replace None with "" + new_variable_dict = {key: value or "" for key, value in variable_dict.items()} + + # Check if the variable_dict contains any special characters + for key, value in new_variable_dict.items(): + if value != "" and any(char in value for char in special_chars): + if enable_escape: + # Escape special characters + new_variable_dict[key] = ( + value.replace("@", "\\@") + .replace("#", "\\#") + .replace("%", "\\%") + .replace(":", "\\:") + ) + else: + raise ValueError( + f"{key} contains special characters, error value: {value}, special " + f"characters: {special_chars}" + ) + + key = new_variable_dict.get("key", "") + name = new_variable_dict.get("name", "") + scope = new_variable_dict.get("scope", "") + scope_key = new_variable_dict.get("scope_key", "") + sys_code = new_variable_dict.get("sys_code", "") + user_name = new_variable_dict.get("user_name", "") + + # Construct the base of the variable string + variable_str = f"${{{key}" + + # Add name if present + if name: + variable_str += f":{name}" + + # Add scope and scope_key if present + if scope or scope_key: + variable_str += f"@{scope}" + if scope_key: + variable_str += f":{scope_key}" + + # Add sys_code if present + if sys_code: + variable_str += f"#{sys_code}" + + # Add user_name if present + if user_name: + variable_str += f"%{user_name}" + + # Close the variable string + variable_str += "}" + + return variable_str diff --git a/dbgpt/serve/agent/app/recommend_question/recommend_question.py b/dbgpt/serve/agent/app/recommend_question/recommend_question.py index 789965a54..6356dd9f8 100644 --- a/dbgpt/serve/agent/app/recommend_question/recommend_question.py +++ b/dbgpt/serve/agent/app/recommend_question/recommend_question.py @@ -87,7 +87,7 @@ class RecommendQuestionEntity(Model): default=False, comment="hot question would be displayed on the main page.", ) - __table_args__ = (Index("idx_app_code", "app_code"),) + __table_args__ = (Index("idx_rec_q_app_code", "app_code"),) class RecommendQuestionDao(BaseDao): diff --git a/dbgpt/serve/agent/db/gpts_app.py b/dbgpt/serve/agent/db/gpts_app.py index a744c2247..7a0b6a148 100644 --- a/dbgpt/serve/agent/db/gpts_app.py +++ b/dbgpt/serve/agent/db/gpts_app.py @@ -268,7 +268,7 @@ class UserRecentAppsEntity(Model): ) last_accessed = Column(DateTime, default=None, comment="last access time") __table_args__ = ( - Index("idx_app_code", "app_code"), + Index("idx_user_r_app_code", "app_code"), Index("idx_user_code", "user_code"), Index("idx_last_accessed", "last_accessed"), ) @@ -451,9 +451,11 @@ class UserRecentAppsDao(BaseDao): "sys_code": sys_code, "user_code": user_code, "last_accessed": last_accessed, - "gmt_create": existing_app.gmt_create - if existing_app - else new_app.gmt_create, + "gmt_create": ( + existing_app.gmt_create + if existing_app + else new_app.gmt_create + ), "gmt_modified": last_accessed, } ) @@ -655,9 +657,9 @@ class GptsAppDao(BaseDao): apps = sorted( apps, - key=lambda obj: float("-inf") - if obj.hot_value is None - else obj.hot_value, + key=lambda obj: ( + float("-inf") if obj.hot_value is None else obj.hot_value + ), 
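+            # Apps without a hot value sort last here, since the sort below
+            # is descending (reverse=True).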
reverse=True, ) app_resp.total_count = total_count @@ -696,19 +698,19 @@ class GptsAppDao(BaseDao): for item in app_details ], "published": app_info.published, - "param_need": json.loads(app_info.param_need) - if app_info.param_need - else None, - "hot_value": hot_app_map.get(app_info.app_code, 0) - if hot_app_map is not None - else 0, + "param_need": ( + json.loads(app_info.param_need) if app_info.param_need else None + ), + "hot_value": ( + hot_app_map.get(app_info.app_code, 0) if hot_app_map is not None else 0 + ), "owner_name": app_info.user_code, "owner_avatar_url": owner_avatar_url, - "recommend_questions": [ - RecommendQuestion.from_entity(item) for item in recommend_questions - ] - if recommend_questions - else [], + "recommend_questions": ( + [RecommendQuestion.from_entity(item) for item in recommend_questions] + if recommend_questions + else [] + ), "admins": [], } @@ -848,9 +850,9 @@ class GptsAppDao(BaseDao): updated_at=gpts_app.updated_at, icon=gpts_app.icon, published="true" if gpts_app.published else "false", - param_need=json.dumps(gpts_app.param_need) - if gpts_app.param_need - else None, + param_need=( + json.dumps(gpts_app.param_need) if gpts_app.param_need else None + ), ) session.add(app_entity) @@ -869,9 +871,11 @@ class GptsAppDao(BaseDao): resources=json.dumps(resource_dicts, ensure_ascii=False), prompt_template=item.prompt_template, llm_strategy=item.llm_strategy, - llm_strategy_value=None - if item.llm_strategy_value is None - else json.dumps(tuple(item.llm_strategy_value.split(","))), + llm_strategy_value=( + None + if item.llm_strategy_value is None + else json.dumps(tuple(item.llm_strategy_value.split(","))) + ), created_at=item.created_at, updated_at=item.updated_at, ) @@ -915,7 +919,7 @@ class GptsAppDao(BaseDao): app_entity.team_mode = gpts_app.team_mode app_entity.icon = gpts_app.icon app_entity.team_context = _parse_team_context(gpts_app.team_context) - app_entity.param_need = (json.dumps(gpts_app.param_need),) + app_entity.param_need = json.dumps(gpts_app.param_need) session.merge(app_entity) old_details = session.query(GptsAppDetailEntity).filter( @@ -936,9 +940,11 @@ class GptsAppDao(BaseDao): resources=json.dumps(resource_dicts, ensure_ascii=False), prompt_template=item.prompt_template, llm_strategy=item.llm_strategy, - llm_strategy_value=None - if item.llm_strategy_value is None - else json.dumps(tuple(item.llm_strategy_value.split(","))), + llm_strategy_value=( + None + if item.llm_strategy_value is None + else json.dumps(tuple(item.llm_strategy_value.split(","))) + ), created_at=item.created_at, updated_at=item.updated_at, ) diff --git a/dbgpt/serve/core/__init__.py b/dbgpt/serve/core/__init__.py index 090288128..31edd5d6c 100644 --- a/dbgpt/serve/core/__init__.py +++ b/dbgpt/serve/core/__init__.py @@ -1,7 +1,11 @@ +from typing import Any + from dbgpt.serve.core.config import BaseServeConfig from dbgpt.serve.core.schemas import Result, add_exception_handler from dbgpt.serve.core.serve import BaseServe from dbgpt.serve.core.service import BaseService +from dbgpt.util.executor_utils import BlockingFunction, DefaultExecutorFactory +from dbgpt.util.executor_utils import blocking_func_to_async as _blocking_func_to_async __ALL__ = [ "Result", @@ -10,3 +14,11 @@ __ALL__ = [ "BaseService", "BaseServe", ] + + +async def blocking_func_to_async( + system_app, func: BlockingFunction, *args, **kwargs +) -> Any: + """Run a potentially blocking function within an executor.""" + executor = DefaultExecutorFactory.get_instance(system_app).create() + return await 
_blocking_func_to_async(executor, func, *args, **kwargs)
diff --git a/dbgpt/serve/file/__init__.py b/dbgpt/serve/file/__init__.py
new file mode 100644
index 000000000..54a428180
--- /dev/null
+++ b/dbgpt/serve/file/__init__.py
@@ -0,0 +1,2 @@
+# This is an auto-generated __init__.py file
+# generated by `dbgpt new serve file`
diff --git a/dbgpt/serve/file/api/__init__.py b/dbgpt/serve/file/api/__init__.py
new file mode 100644
index 000000000..54a428180
--- /dev/null
+++ b/dbgpt/serve/file/api/__init__.py
@@ -0,0 +1,2 @@
+# This is an auto-generated __init__.py file
+# generated by `dbgpt new serve file`
diff --git a/dbgpt/serve/file/api/endpoints.py b/dbgpt/serve/file/api/endpoints.py
new file mode 100644
index 000000000..26bbb9673
--- /dev/null
+++ b/dbgpt/serve/file/api/endpoints.py
@@ -0,0 +1,169 @@
+import logging
+from functools import cache
+from typing import List, Optional
+from urllib.parse import quote
+
+from fastapi import APIRouter, Depends, HTTPException, Query, UploadFile
+from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
+from starlette.responses import StreamingResponse
+
+from dbgpt.component import SystemApp
+from dbgpt.serve.core import Result, blocking_func_to_async
+from dbgpt.util import PaginationResult
+
+from ..config import APP_NAME, SERVE_APP_NAME, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
+from ..service.service import Service
+from .schemas import ServeRequest, ServerResponse, UploadFileResponse
+
+router = APIRouter()
+logger = logging.getLogger(__name__)
+
+# Add your API endpoints here
+
+global_system_app: Optional[SystemApp] = None
+
+
+def get_service() -> Service:
+    """Get the service instance"""
+    return global_system_app.get_component(SERVE_SERVICE_COMPONENT_NAME, Service)
+
+
+get_bearer_token = HTTPBearer(auto_error=False)
+
+
+@cache
+def _parse_api_keys(api_keys: str) -> List[str]:
+    """Parse the string api keys to a list
+
+    Args:
+        api_keys (str): The string api keys
+
+    Returns:
+        List[str]: The list of api keys
+    """
+    if not api_keys:
+        return []
+    return [key.strip() for key in api_keys.split(",")]
+
+
+async def check_api_key(
+    auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token),
+    service: Service = Depends(get_service),
+) -> Optional[str]:
+    """Check the api key
+
+    If the api key is not set, allow all.
+
+    You can pass the token in your request header like this:
+
+    ..
code-block:: python + + import requests + + client_api_key = "your_api_key" + headers = {"Authorization": "Bearer " + client_api_key} + res = requests.get("http://test/hello", headers=headers) + assert res.status_code == 200 + + """ + if service.config.api_keys: + api_keys = _parse_api_keys(service.config.api_keys) + if auth is None or (token := auth.credentials) not in api_keys: + raise HTTPException( + status_code=401, + detail={ + "error": { + "message": "", + "type": "invalid_request_error", + "param": None, + "code": "invalid_api_key", + } + }, + ) + return token + else: + # api_keys not set; allow all + return None + + +@router.get("/health") +async def health(): + """Health check endpoint""" + return {"status": "ok"} + + +@router.get("/test_auth", dependencies=[Depends(check_api_key)]) +async def test_auth(): + """Test auth endpoint""" + return {"status": "ok"} + + +@router.post( + "/files/{bucket}", + response_model=Result[List[UploadFileResponse]], + dependencies=[Depends(check_api_key)], +) +async def upload_files( + bucket: str, + files: List[UploadFile], + user_name: Optional[str] = Query(default=None, description="user name"), + sys_code: Optional[str] = Query(default=None, description="system code"), + service: Service = Depends(get_service), +) -> Result[List[UploadFileResponse]]: + """Upload files by a list of UploadFile.""" + logger.info(f"upload_files: bucket={bucket}, files={files}") + results = await blocking_func_to_async( + global_system_app, + service.upload_files, + bucket, + "distributed", + files, + user_name, + sys_code, + ) + return Result.succ(results) + + +@router.get("/files/{bucket}/{file_id}", dependencies=[Depends(check_api_key)]) +async def download_file( + bucket: str, file_id: str, service: Service = Depends(get_service) +): + """Download a file by file_id.""" + logger.info(f"download_file: bucket={bucket}, file_id={file_id}") + file_data, file_metadata = await blocking_func_to_async( + global_system_app, service.download_file, bucket, file_id + ) + file_name_encoded = quote(file_metadata.file_name) + + def file_iterator(raw_iter): + with raw_iter: + while chunk := raw_iter.read( + service.config.file_server_download_chunk_size + ): + yield chunk + + response = StreamingResponse( + file_iterator(file_data), media_type="application/octet-stream" + ) + response.headers[ + "Content-Disposition" + ] = f"attachment; filename={file_name_encoded}" + return response + + +@router.delete("/files/{bucket}/{file_id}", dependencies=[Depends(check_api_key)]) +async def delete_file( + bucket: str, file_id: str, service: Service = Depends(get_service) +): + """Delete a file by file_id.""" + await blocking_func_to_async( + global_system_app, service.delete_file, bucket, file_id + ) + return Result.succ(None) + + +def init_endpoints(system_app: SystemApp) -> None: + """Initialize the endpoints""" + global global_system_app + system_app.register(Service) + global_system_app = system_app diff --git a/dbgpt/serve/file/api/schemas.py b/dbgpt/serve/file/api/schemas.py new file mode 100644 index 000000000..911f71db3 --- /dev/null +++ b/dbgpt/serve/file/api/schemas.py @@ -0,0 +1,43 @@ +# Define your Pydantic schemas here +from typing import Any, Dict + +from dbgpt._private.pydantic import BaseModel, ConfigDict, Field, model_to_dict + +from ..config import SERVE_APP_NAME_HUMP + + +class ServeRequest(BaseModel): + """File request model""" + + # TODO define your own fields here + + model_config = ConfigDict(title=f"ServeRequest for {SERVE_APP_NAME_HUMP}") + + def to_dict(self, 
**kwargs) -> Dict[str, Any]: + """Convert the model to a dictionary""" + return model_to_dict(self, **kwargs) + + +class ServerResponse(BaseModel): + """File response model""" + + # TODO define your own fields here + + model_config = ConfigDict(title=f"ServerResponse for {SERVE_APP_NAME_HUMP}") + + def to_dict(self, **kwargs) -> Dict[str, Any]: + """Convert the model to a dictionary""" + return model_to_dict(self, **kwargs) + + +class UploadFileResponse(BaseModel): + """Upload file response model""" + + file_name: str = Field(..., title="The name of the uploaded file") + file_id: str = Field(..., title="The ID of the uploaded file") + bucket: str = Field(..., title="The bucket of the uploaded file") + uri: str = Field(..., title="The URI of the uploaded file") + + def to_dict(self, **kwargs) -> Dict[str, Any]: + """Convert the model to a dictionary""" + return model_to_dict(self, **kwargs) diff --git a/dbgpt/serve/file/config.py b/dbgpt/serve/file/config.py new file mode 100644 index 000000000..1ab1afede --- /dev/null +++ b/dbgpt/serve/file/config.py @@ -0,0 +1,68 @@ +from dataclasses import dataclass, field +from typing import Optional + +from dbgpt.serve.core import BaseServeConfig + +APP_NAME = "file" +SERVE_APP_NAME = "dbgpt_serve_file" +SERVE_APP_NAME_HUMP = "dbgpt_serve_File" +SERVE_CONFIG_KEY_PREFIX = "dbgpt.serve.file." +SERVE_SERVICE_COMPONENT_NAME = f"{SERVE_APP_NAME}_service" +# Database table name +SERVER_APP_TABLE_NAME = "dbgpt_serve_file" + + +@dataclass +class ServeConfig(BaseServeConfig): + """Parameters for the serve command""" + + # TODO: add your own parameters here + api_keys: Optional[str] = field( + default=None, metadata={"help": "API keys for the endpoint, if None, allow all"} + ) + check_hash: Optional[bool] = field( + default=True, metadata={"help": "Check the hash of the file when downloading"} + ) + file_server_host: Optional[str] = field( + default=None, metadata={"help": "The host of the file server"} + ) + file_server_port: Optional[int] = field( + default=5670, metadata={"help": "The port of the file server"} + ) + file_server_download_chunk_size: Optional[int] = field( + default=1024 * 1024, + metadata={"help": "The chunk size when downloading the file"}, + ) + file_server_save_chunk_size: Optional[int] = field( + default=1024 * 1024, metadata={"help": "The chunk size when saving the file"} + ) + file_server_transfer_chunk_size: Optional[int] = field( + default=1024 * 1024, + metadata={"help": "The chunk size when transferring the file"}, + ) + file_server_transfer_timeout: Optional[int] = field( + default=360, metadata={"help": "The timeout when transferring the file"} + ) + local_storage_path: Optional[str] = field( + default=None, metadata={"help": "The local storage path"} + ) + + def get_node_address(self) -> str: + """Get the node address""" + file_server_host = self.file_server_host + if not file_server_host: + from dbgpt.util.net_utils import _get_ip_address + + file_server_host = _get_ip_address() + file_server_port = self.file_server_port or 5670 + return f"{file_server_host}:{file_server_port}" + + def get_local_storage_path(self) -> str: + """Get the local storage path""" + local_storage_path = self.local_storage_path + if not local_storage_path: + from pathlib import Path + + base_path = Path.home() / ".cache" / "dbgpt" / "files" + local_storage_path = str(base_path) + return local_storage_path diff --git a/dbgpt/serve/file/dependencies.py b/dbgpt/serve/file/dependencies.py new file mode 100644 index 000000000..8598ecd97 --- /dev/null +++ 
b/dbgpt/serve/file/dependencies.py @@ -0,0 +1 @@ +# Define your dependencies here diff --git a/dbgpt/serve/file/models/__init__.py b/dbgpt/serve/file/models/__init__.py new file mode 100644 index 000000000..54a428180 --- /dev/null +++ b/dbgpt/serve/file/models/__init__.py @@ -0,0 +1,2 @@ +# This is an auto-generated __init__.py file +# generated by `dbgpt new serve file` diff --git a/dbgpt/serve/file/models/file_adapter.py b/dbgpt/serve/file/models/file_adapter.py new file mode 100644 index 000000000..29ee831f4 --- /dev/null +++ b/dbgpt/serve/file/models/file_adapter.py @@ -0,0 +1,88 @@ +import json +from typing import Type + +from sqlalchemy.orm import Session + +from dbgpt.core.interface.file import FileMetadata, FileMetadataIdentifier +from dbgpt.core.interface.storage import StorageItemAdapter + +from .models import ServeEntity + + +class FileMetadataAdapter(StorageItemAdapter[FileMetadata, ServeEntity]): + """File metadata adapter. + + Convert between storage format and database model. + """ + + def to_storage_format(self, item: FileMetadata) -> ServeEntity: + """Convert to storage format.""" + custom_metadata = ( + {k: v for k, v in item.custom_metadata.items()} + if item.custom_metadata + else {} + ) + user_name = item.user_name or custom_metadata.get("user_name") + sys_code = item.sys_code or custom_metadata.get("sys_code") + if "user_name" in custom_metadata: + del custom_metadata["user_name"] + if "sys_code" in custom_metadata: + del custom_metadata["sys_code"] + custom_metadata_json = ( + json.dumps(custom_metadata, ensure_ascii=False) if custom_metadata else None + ) + return ServeEntity( + bucket=item.bucket, + file_id=item.file_id, + file_name=item.file_name, + file_size=item.file_size, + storage_type=item.storage_type, + storage_path=item.storage_path, + uri=item.uri, + custom_metadata=custom_metadata_json, + file_hash=item.file_hash, + user_name=user_name, + sys_code=sys_code, + ) + + def from_storage_format(self, model: ServeEntity) -> FileMetadata: + """Convert from storage format.""" + custom_metadata = ( + json.loads(model.custom_metadata) if model.custom_metadata else None + ) + if custom_metadata is None: + custom_metadata = {} + if model.user_name: + custom_metadata["user_name"] = model.user_name + if model.sys_code: + custom_metadata["sys_code"] = model.sys_code + + return FileMetadata( + bucket=model.bucket, + file_id=model.file_id, + file_name=model.file_name, + file_size=model.file_size, + storage_type=model.storage_type, + storage_path=model.storage_path, + uri=model.uri, + custom_metadata=custom_metadata, + file_hash=model.file_hash, + user_name=model.user_name, + sys_code=model.sys_code, + ) + + def get_query_for_identifier( + self, + storage_format: Type[ServeEntity], + resource_id: FileMetadataIdentifier, + **kwargs, + ): + """Get query for identifier.""" + session: Session = kwargs.get("session") + if session is None: + raise Exception("session is None") + return ( + session.query(storage_format) + .filter(storage_format.bucket == resource_id.bucket) + .filter(storage_format.file_id == resource_id.file_id) + ) diff --git a/dbgpt/serve/file/models/models.py b/dbgpt/serve/file/models/models.py new file mode 100644 index 000000000..fd816740d --- /dev/null +++ b/dbgpt/serve/file/models/models.py @@ -0,0 +1,90 @@ +"""This is an auto-generated model file +You can define your own models and DAOs here +""" + +from datetime import datetime +from typing import Any, Dict, Union + +from sqlalchemy import Column, DateTime, Index, Integer, String, Text, UniqueConstraint 
+ +from dbgpt.storage.metadata import BaseDao, Model, db + +from ..api.schemas import ServeRequest, ServerResponse +from ..config import SERVER_APP_TABLE_NAME, ServeConfig + + +class ServeEntity(Model): + __tablename__ = SERVER_APP_TABLE_NAME + __table_args__ = (UniqueConstraint("bucket", "file_id", name="uk_bucket_file_id"),) + + id = Column(Integer, primary_key=True, comment="Auto increment id") + + bucket = Column(String(255), nullable=False, comment="Bucket name") + file_id = Column(String(255), nullable=False, comment="File id") + file_name = Column(String(256), nullable=False, comment="File name") + file_size = Column(Integer, nullable=True, comment="File size") + storage_type = Column(String(32), nullable=False, comment="Storage type") + storage_path = Column(String(512), nullable=False, comment="Storage path") + uri = Column(String(512), nullable=False, comment="File URI") + custom_metadata = Column( + Text, nullable=True, comment="Custom metadata, JSON format" + ) + file_hash = Column(String(128), nullable=True, comment="File hash") + user_name = Column(String(128), index=True, nullable=True, comment="User name") + sys_code = Column(String(128), index=True, nullable=True, comment="System code") + gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time") + gmt_modified = Column(DateTime, default=datetime.now, comment="Record update time") + + def __repr__(self): + return ( + f"ServeEntity(id={self.id}, gmt_created='{self.gmt_created}', " + f"gmt_modified='{self.gmt_modified}')" + ) + + +class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]): + """The DAO class for File""" + + def __init__(self, serve_config: ServeConfig): + super().__init__() + self._serve_config = serve_config + + def from_request(self, request: Union[ServeRequest, Dict[str, Any]]) -> ServeEntity: + """Convert the request to an entity + + Args: + request (Union[ServeRequest, Dict[str, Any]]): The request + + Returns: + T: The entity + """ + request_dict = ( + request.to_dict() if isinstance(request, ServeRequest) else request + ) + entity = ServeEntity(**request_dict) + # TODO implement your own logic here, transfer the request_dict to an entity + return entity + + def to_request(self, entity: ServeEntity) -> ServeRequest: + """Convert the entity to a request + + Args: + entity (T): The entity + + Returns: + REQ: The request + """ + # TODO implement your own logic here, transfer the entity to a request + return ServeRequest() + + def to_response(self, entity: ServeEntity) -> ServerResponse: + """Convert the entity to a response + + Args: + entity (T): The entity + + Returns: + RES: The response + """ + # TODO implement your own logic here, transfer the entity to a response + return ServerResponse() diff --git a/dbgpt/serve/file/serve.py b/dbgpt/serve/file/serve.py new file mode 100644 index 000000000..559509573 --- /dev/null +++ b/dbgpt/serve/file/serve.py @@ -0,0 +1,113 @@ +import logging +from typing import List, Optional, Union + +from sqlalchemy import URL + +from dbgpt.component import SystemApp +from dbgpt.core.interface.file import FileStorageClient +from dbgpt.serve.core import BaseServe +from dbgpt.storage.metadata import DatabaseManager + +from .api.endpoints import init_endpoints, router +from .config import ( + APP_NAME, + SERVE_APP_NAME, + SERVE_APP_NAME_HUMP, + SERVE_CONFIG_KEY_PREFIX, + ServeConfig, +) + +logger = logging.getLogger(__name__) + + +class Serve(BaseServe): + """Serve component for DB-GPT""" + + name = SERVE_APP_NAME + + def __init__( + self, + 
        system_app: SystemApp,
+        api_prefix: Optional[str] = f"/api/v2/serve/{APP_NAME}",
+        api_tags: Optional[List[str]] = None,
+        db_url_or_db: Union[str, URL, DatabaseManager] = None,
+        try_create_tables: Optional[bool] = False,
+    ):
+        if api_tags is None:
+            api_tags = [SERVE_APP_NAME_HUMP]
+        super().__init__(
+            system_app, api_prefix, api_tags, db_url_or_db, try_create_tables
+        )
+        self._db_manager: Optional[DatabaseManager] = None
+        self._file_storage_client: Optional[FileStorageClient] = None
+        self._serve_config: Optional[ServeConfig] = None
+
+    def init_app(self, system_app: SystemApp):
+        if self._app_has_initiated:
+            return
+        self._system_app = system_app
+        self._system_app.app.include_router(
+            router, prefix=self._api_prefix, tags=self._api_tags
+        )
+        init_endpoints(self._system_app)
+        self._app_has_initiated = True
+
+    def on_init(self):
+        """Called when init the application.
+
+        You can do some initialization here. You can't get other components
+        here because they may not be initialized yet.
+        """
+        # Import your own module here to ensure the module is loaded before
+        # the application starts
+        from .models.models import ServeEntity
+
+    def before_start(self):
+        """Called before the start of the application."""
+        from dbgpt.core.interface.file import (
+            FileStorageSystem,
+            SimpleDistributedStorage,
+        )
+        from dbgpt.storage.metadata.db_storage import SQLAlchemyStorage
+        from dbgpt.util.serialization.json_serialization import JsonSerializer
+
+        from .models.file_adapter import FileMetadataAdapter
+        from .models.models import ServeEntity
+
+        self._serve_config = ServeConfig.from_app_config(
+            self._system_app.config, SERVE_CONFIG_KEY_PREFIX
+        )
+
+        self._db_manager = self.create_or_get_db_manager()
+        serializer = JsonSerializer()
+        storage = SQLAlchemyStorage(
+            self._db_manager,
+            ServeEntity,
+            FileMetadataAdapter(),
+            serializer,
+        )
+        simple_distributed_storage = SimpleDistributedStorage(
+            node_address=self._serve_config.get_node_address(),
+            local_storage_path=self._serve_config.get_local_storage_path(),
+            save_chunk_size=self._serve_config.file_server_save_chunk_size,
+            transfer_chunk_size=self._serve_config.file_server_transfer_chunk_size,
+            transfer_timeout=self._serve_config.file_server_transfer_timeout,
+        )
+        storage_backends = {
+            simple_distributed_storage.storage_type: simple_distributed_storage,
+        }
+        fs = FileStorageSystem(
+            storage_backends,
+            metadata_storage=storage,
+            check_hash=self._serve_config.check_hash,
+        )
+        self._file_storage_client = FileStorageClient(
+            system_app=self._system_app, storage_system=fs
+        )
+
+    @property
+    def file_storage_client(self) -> FileStorageClient:
+        """Returns the file storage client."""
+        if not self._file_storage_client:
+            raise ValueError("File storage client is not initialized")
+        return self._file_storage_client
diff --git a/dbgpt/serve/file/service/__init__.py b/dbgpt/serve/file/service/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/dbgpt/serve/file/service/service.py b/dbgpt/serve/file/service/service.py
new file mode 100644
index 000000000..13e8b6225
--- /dev/null
+++ b/dbgpt/serve/file/service/service.py
@@ -0,0 +1,119 @@
+import logging
+from typing import BinaryIO, List, Optional, Tuple
+
+from fastapi import UploadFile
+
+from dbgpt.component import BaseComponent, SystemApp
+from dbgpt.core.interface.file import FileMetadata, FileStorageClient, FileStorageURI
+from dbgpt.serve.core import BaseService
+from dbgpt.storage.metadata import BaseDao
+from dbgpt.util.pagination_utils import PaginationResult +from dbgpt.util.tracer import root_tracer, trace + +from ..api.schemas import ServeRequest, ServerResponse, UploadFileResponse +from ..config import SERVE_CONFIG_KEY_PREFIX, SERVE_SERVICE_COMPONENT_NAME, ServeConfig +from ..models.models import ServeDao, ServeEntity + +logger = logging.getLogger(__name__) + + +class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]): + """The service class for File""" + + name = SERVE_SERVICE_COMPONENT_NAME + + def __init__(self, system_app: SystemApp, dao: Optional[ServeDao] = None): + self._system_app = None + self._serve_config: ServeConfig = None + self._dao: ServeDao = dao + super().__init__(system_app) + + def init_app(self, system_app: SystemApp) -> None: + """Initialize the service + + Args: + system_app (SystemApp): The system app + """ + super().init_app(system_app) + self._serve_config = ServeConfig.from_app_config( + system_app.config, SERVE_CONFIG_KEY_PREFIX + ) + self._dao = self._dao or ServeDao(self._serve_config) + self._system_app = system_app + + @property + def dao(self) -> BaseDao[ServeEntity, ServeRequest, ServerResponse]: + """Returns the internal DAO.""" + return self._dao + + @property + def config(self) -> ServeConfig: + """Returns the internal ServeConfig.""" + return self._serve_config + + @property + def file_storage_client(self) -> FileStorageClient: + """Returns the internal FileStorageClient. + + Returns: + FileStorageClient: The internal FileStorageClient + """ + file_storage_client = FileStorageClient.get_instance( + self._system_app, default_component=None + ) + if file_storage_client: + return file_storage_client + else: + from ..serve import Serve + + file_storage_client = Serve.get_instance( + self._system_app + ).file_storage_client + self._system_app.register_instance(file_storage_client) + return file_storage_client + + @trace("upload_files") + def upload_files( + self, + bucket: str, + storage_type: str, + files: List[UploadFile], + user_name: Optional[str] = None, + sys_code: Optional[str] = None, + ) -> List[UploadFileResponse]: + """Upload files by a list of UploadFile.""" + results = [] + for file in files: + file_name = file.filename + logger.info(f"Uploading file {file_name} to bucket {bucket}") + custom_metadata = { + "user_name": user_name, + "sys_code": sys_code, + } + uri = self.file_storage_client.save_file( + bucket, + file_name, + file_data=file.file, + storage_type=storage_type, + custom_metadata=custom_metadata, + ) + parsed_uri = FileStorageURI.parse(uri) + logger.info(f"Uploaded file {file_name} to bucket {bucket}, uri={uri}") + results.append( + UploadFileResponse( + file_name=file_name, + file_id=parsed_uri.file_id, + bucket=bucket, + uri=uri, + ) + ) + return results + + @trace("download_file") + def download_file(self, bucket: str, file_id: str) -> Tuple[BinaryIO, FileMetadata]: + """Download a file by file_id.""" + return self.file_storage_client.get_file_by_id(bucket, file_id) + + def delete_file(self, bucket: str, file_id: str) -> None: + """Delete a file by file_id.""" + self.file_storage_client.delete_file_by_id(bucket, file_id) diff --git a/dbgpt/serve/file/tests/__init__.py b/dbgpt/serve/file/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/dbgpt/serve/file/tests/test_endpoints.py b/dbgpt/serve/file/tests/test_endpoints.py new file mode 100644 index 000000000..ba7b4f0cd --- /dev/null +++ b/dbgpt/serve/file/tests/test_endpoints.py @@ -0,0 +1,124 @@ +import pytest +from fastapi import 
FastAPI
+from httpx import AsyncClient
+
+from dbgpt.component import SystemApp
+from dbgpt.serve.core.tests.conftest import asystem_app, client
+from dbgpt.storage.metadata import db
+from dbgpt.util import PaginationResult
+
+from ..api.endpoints import init_endpoints, router
+from ..api.schemas import ServeRequest, ServerResponse
+from ..config import SERVE_CONFIG_KEY_PREFIX
+
+
+@pytest.fixture(autouse=True)
+def setup_and_teardown():
+    db.init_db("sqlite:///:memory:")
+    db.create_all()
+
+    yield
+
+
+def client_init_caller(app: FastAPI, system_app: SystemApp):
+    app.include_router(router)
+    init_endpoints(system_app)
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "client, asystem_app, has_auth",
+    [
+        (
+            {
+                "app_caller": client_init_caller,
+                "client_api_key": "test_token1",
+            },
+            {
+                "app_config": {
+                    f"{SERVE_CONFIG_KEY_PREFIX}api_keys": "test_token1,test_token2"
+                }
+            },
+            True,
+        ),
+        (
+            {
+                "app_caller": client_init_caller,
+                "client_api_key": "error_token",
+            },
+            {
+                "app_config": {
+                    f"{SERVE_CONFIG_KEY_PREFIX}api_keys": "test_token1,test_token2"
+                }
+            },
+            False,
+        ),
+    ],
+    indirect=["client", "asystem_app"],
+)
+async def test_api_auth(client: AsyncClient, asystem_app, has_auth: bool):
+    response = await client.get("/test_auth")
+    if has_auth:
+        assert response.status_code == 200
+        assert response.json() == {"status": "ok"}
+    else:
+        assert response.status_code == 401
+        assert response.json() == {
+            "detail": {
+                "error": {
+                    "message": "",
+                    "type": "invalid_request_error",
+                    "param": None,
+                    "code": "invalid_api_key",
+                }
+            }
+        }
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "client", [{"app_caller": client_init_caller}], indirect=["client"]
+)
+async def test_api_health(client: AsyncClient):
+    response = await client.get("/health")
+    assert response.status_code == 200
+    assert response.json() == {"status": "ok"}
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "client", [{"app_caller": client_init_caller}], indirect=["client"]
+)
+async def test_api_create(client: AsyncClient):
+    # TODO: add your test case
+    pass
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "client", [{"app_caller": client_init_caller}], indirect=["client"]
+)
+async def test_api_update(client: AsyncClient):
+    # TODO: implement your test case
+    pass
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "client", [{"app_caller": client_init_caller}], indirect=["client"]
+)
+async def test_api_query(client: AsyncClient):
+    # TODO: implement your test case
+    pass
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "client", [{"app_caller": client_init_caller}], indirect=["client"]
+)
+async def test_api_query_by_page(client: AsyncClient):
+    # TODO: implement your test case
+    pass
+
+
+# Add more test cases according to your own logic
diff --git a/dbgpt/serve/file/tests/test_models.py b/dbgpt/serve/file/tests/test_models.py
new file mode 100644
index 000000000..8b66e9f97
--- /dev/null
+++ b/dbgpt/serve/file/tests/test_models.py
@@ -0,0 +1,99 @@
+import pytest
+
+from dbgpt.storage.metadata import db
+
+from ..api.schemas import ServeRequest, ServerResponse
+from ..config import ServeConfig
+from ..models.models import ServeDao, ServeEntity
+
+
+@pytest.fixture(autouse=True)
+def setup_and_teardown():
+    db.init_db("sqlite:///:memory:")
+    db.create_all()
+
+    yield
+
+
+@pytest.fixture
+def server_config():
+    # TODO : build your server config
+    return ServeConfig()
+
+
+@pytest.fixture
+def dao(server_config):
+    return ServeDao(server_config)
+
+
+@pytest.fixture
+def
default_entity_dict(): + # TODO: build your default entity dict + return {} + + +def test_table_exist(): + assert ServeEntity.__tablename__ in db.metadata.tables + + +def test_entity_create(default_entity_dict): + # TODO: implement your test case + pass + + +def test_entity_unique_key(default_entity_dict): + # TODO: implement your test case + pass + + +def test_entity_get(default_entity_dict): + # TODO: implement your test case + pass + + +def test_entity_update(default_entity_dict): + # TODO: implement your test case + pass + + +def test_entity_delete(default_entity_dict): + # TODO: implement your test case + pass + + +def test_entity_all(): + # TODO: implement your test case + pass + + +def test_dao_create(dao, default_entity_dict): + # TODO: implement your test case + pass + + +def test_dao_get_one(dao, default_entity_dict): + # TODO: implement your test case + pass + + +def test_get_dao_get_list(dao): + # TODO: implement your test case + pass + + +def test_dao_update(dao, default_entity_dict): + # TODO: implement your test case + pass + + +def test_dao_delete(dao, default_entity_dict): + # TODO: implement your test case + pass + + +def test_dao_get_list_page(dao): + # TODO: implement your test case + pass + + +# Add more test cases according to your own logic diff --git a/dbgpt/serve/file/tests/test_service.py b/dbgpt/serve/file/tests/test_service.py new file mode 100644 index 000000000..00177924d --- /dev/null +++ b/dbgpt/serve/file/tests/test_service.py @@ -0,0 +1,78 @@ +from typing import List + +import pytest + +from dbgpt.component import SystemApp +from dbgpt.serve.core.tests.conftest import system_app +from dbgpt.storage.metadata import db + +from ..api.schemas import ServeRequest, ServerResponse +from ..models.models import ServeEntity +from ..service.service import Service + + +@pytest.fixture(autouse=True) +def setup_and_teardown(): + db.init_db("sqlite:///:memory:") + db.create_all() + yield + + +@pytest.fixture +def service(system_app: SystemApp): + instance = Service(system_app) + instance.init_app(system_app) + return instance + + +@pytest.fixture +def default_entity_dict(): + # TODO: build your default entity dict + return {} + + +@pytest.mark.parametrize( + "system_app", + [{"app_config": {"DEBUG": True, "dbgpt.serve.test_key": "hello"}}], + indirect=True, +) +def test_config_exists(service: Service): + system_app: SystemApp = service._system_app + assert system_app.config.get("DEBUG") is True + assert system_app.config.get("dbgpt.serve.test_key") == "hello" + assert service.config is not None + + +def test_service_create(service: Service, default_entity_dict): + # TODO: implement your test case + # eg. entity: ServerResponse = service.create(ServeRequest(**default_entity_dict)) + # ... 
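+    # An illustrative sketch only (assumes the default BaseService API):
+    #     entity = service.create(ServeRequest(**default_entity_dict))
+    #     fetched = service.get(ServeRequest(**default_entity_dict))
+    #     assert fetched is not None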
+ pass + + +def test_service_update(service: Service, default_entity_dict): + # TODO: implement your test case + pass + + +def test_service_get(service: Service, default_entity_dict): + # TODO: implement your test case + pass + + +def test_service_delete(service: Service, default_entity_dict): + # TODO: implement your test case + pass + + +def test_service_get_list(service: Service): + # TODO: implement your test case + pass + + +def test_service_get_list_by_page(service: Service): + # TODO: implement your test case + pass + + +# Add more test cases according to your own logic diff --git a/dbgpt/serve/flow/api/endpoints.py b/dbgpt/serve/flow/api/endpoints.py index 0ddc583fe..28f05d532 100644 --- a/dbgpt/serve/flow/api/endpoints.py +++ b/dbgpt/serve/flow/api/endpoints.py @@ -1,18 +1,29 @@ +import io +import json from functools import cache -from typing import List, Optional, Union +from typing import Dict, List, Literal, Optional, Union -from fastapi import APIRouter, Depends, HTTPException, Query, Request +from fastapi import APIRouter, Depends, File, HTTPException, Query, Request, UploadFile from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer +from starlette.responses import JSONResponse, StreamingResponse from dbgpt.component import SystemApp from dbgpt.core.awel.flow import ResourceMetadata, ViewMetadata from dbgpt.core.awel.flow.flow_factory import FlowCategory -from dbgpt.serve.core import Result +from dbgpt.serve.core import Result, blocking_func_to_async from dbgpt.util import PaginationResult from ..config import APP_NAME, SERVE_SERVICE_COMPONENT_NAME, ServeConfig from ..service.service import Service -from .schemas import ServeRequest, ServerResponse +from ..service.variables_service import VariablesService +from .schemas import ( + FlowDebugRequest, + RefreshNodeRequest, + ServeRequest, + ServerResponse, + VariablesRequest, + VariablesResponse, +) router = APIRouter() @@ -23,7 +34,12 @@ global_system_app: Optional[SystemApp] = None def get_service() -> Service: """Get the service instance""" - return global_system_app.get_component(SERVE_SERVICE_COMPONENT_NAME, Service) + return Service.get_instance(global_system_app) + + +def get_variable_service() -> VariablesService: + """Get the service instance""" + return VariablesService.get_instance(global_system_app) get_bearer_token = HTTPBearer(auto_error=False) @@ -102,7 +118,9 @@ async def test_auth(): @router.post( - "/flows", response_model=Result[None], dependencies=[Depends(check_api_key)] + "/flows", + response_model=Result[ServerResponse], + dependencies=[Depends(check_api_key)], ) async def create( request: ServeRequest, service: Service = Depends(get_service) @@ -239,20 +257,236 @@ async def query_page( @router.get("/nodes", dependencies=[Depends(check_api_key)]) -async def get_nodes() -> Result[List[Union[ViewMetadata, ResourceMetadata]]]: +async def get_nodes( + user_name: Optional[str] = Query(default=None, description="user name"), + sys_code: Optional[str] = Query(default=None, description="system code"), + tags: Optional[str] = Query(default=None, description="tags"), +): """Get the operator or resource nodes + Args: + user_name (Optional[str]): The username + sys_code (Optional[str]): The system code + tags (Optional[str]): The tags encoded in JSON format + Returns: Result[List[Union[ViewMetadata, ResourceMetadata]]]: The operator or resource nodes """ from dbgpt.core.awel.flow.base import _OPERATOR_REGISTRY - return Result.succ(_OPERATOR_REGISTRY.metadata_list()) + tags_dict: Optional[Dict[str, 
str]] = None + if tags: + try: + tags_dict = json.loads(tags) + except json.JSONDecodeError: + return Result.fail("Invalid JSON format for tags") + + metadata_list = await blocking_func_to_async( + global_system_app, + _OPERATOR_REGISTRY.metadata_list, + tags_dict, + user_name, + sys_code, + ) + return Result.succ(metadata_list) + + +@router.post("/nodes/refresh", dependencies=[Depends(check_api_key)]) +async def refresh_nodes(refresh_request: RefreshNodeRequest): + """Refresh the operator or resource nodes + + Returns: + Result[None]: The response + """ + from dbgpt.core.awel.flow.base import _OPERATOR_REGISTRY + + # Make sure the variables provider is initialized + _ = get_variable_service().variables_provider + + new_metadata = await _OPERATOR_REGISTRY.refresh( + refresh_request.id, + refresh_request.flow_type == "operator", + refresh_request.refresh, + "http", + global_system_app, + ) + return Result.succ(new_metadata) + + +@router.post( + "/variables", + response_model=Result[VariablesResponse], + dependencies=[Depends(check_api_key)], +) +async def create_variables( + variables_request: VariablesRequest, +) -> Result[VariablesResponse]: + """Create a new Variables entity + + Args: + variables_request (VariablesRequest): The request + Returns: + VariablesResponse: The response + """ + res = await blocking_func_to_async( + global_system_app, get_variable_service().create, variables_request + ) + return Result.succ(res) + + +@router.put( + "/variables/{v_id}", + response_model=Result[VariablesResponse], + dependencies=[Depends(check_api_key)], +) +async def update_variables( + v_id: int, variables_request: VariablesRequest +) -> Result[VariablesResponse]: + """Update a Variables entity + + Args: + v_id (int): The variable id + variables_request (VariablesRequest): The request + Returns: + VariablesResponse: The response + """ + res = await blocking_func_to_async( + global_system_app, get_variable_service().update, v_id, variables_request + ) + return Result.succ(res) + + +@router.post("/flow/debug", dependencies=[Depends(check_api_key)]) +async def debug_flow( + flow_debug_request: FlowDebugRequest, service: Service = Depends(get_service) +): + """Run the flow in debug mode.""" + # Return the no-incremental stream by default + stream_iter = service.debug_flow(flow_debug_request, default_incremental=False) + + headers = { + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "Transfer-Encoding": "chunked", + } + return StreamingResponse( + service._wrapper_chat_stream_flow_str(stream_iter), + headers=headers, + media_type="text/event-stream", + ) + + +@router.get("/flow/export/{uid}", dependencies=[Depends(check_api_key)]) +async def export_flow( + uid: str, + export_type: Literal["json", "dbgpts"] = Query( + "json", description="export type(json or dbgpts)" + ), + format: Literal["file", "json"] = Query( + "file", description="response format(file or json)" + ), + file_name: Optional[str] = Query(default=None, description="file name to export"), + user_name: Optional[str] = Query(default=None, description="user name"), + sys_code: Optional[str] = Query(default=None, description="system code"), + service: Service = Depends(get_service), +): + """Export the flow to a file.""" + flow = service.get({"uid": uid, "user_name": user_name, "sys_code": sys_code}) + if not flow: + raise HTTPException(status_code=404, detail=f"Flow {uid} not found") + package_name = flow.name.replace("_", "-") + file_name = file_name or package_name + if export_type 
== "json": + flow_dict = {"flow": flow.to_dict()} + if format == "json": + return JSONResponse(content=flow_dict) + else: + # Return the json file + return StreamingResponse( + io.BytesIO(json.dumps(flow_dict, ensure_ascii=False).encode("utf-8")), + media_type="application/file", + headers={ + "Content-Disposition": f"attachment;filename={file_name}.json" + }, + ) + + elif export_type == "dbgpts": + from ..service.share_utils import _generate_dbgpts_zip + + if format == "json": + raise HTTPException( + status_code=400, detail="json response is not supported for dbgpts" + ) + + zip_buffer = await blocking_func_to_async( + global_system_app, _generate_dbgpts_zip, package_name, flow + ) + return StreamingResponse( + zip_buffer, + media_type="application/x-zip-compressed", + headers={"Content-Disposition": f"attachment;filename={file_name}.zip"}, + ) + + +@router.post( + "/flow/import", + response_model=Result[ServerResponse], + dependencies=[Depends(check_api_key)], +) +async def import_flow( + file: UploadFile = File(...), + save_flow: bool = Query( + False, description="Whether to save the flow after importing" + ), + service: Service = Depends(get_service), +): + """Import the flow from a file.""" + filename = file.filename + file_extension = filename.split(".")[-1].lower() + if file_extension == "json": + # Handle json file + json_content = await file.read() + json_dict = json.loads(json_content) + if "flow" not in json_dict: + raise HTTPException( + status_code=400, detail="invalid json file, missing 'flow' key" + ) + flow = ServeRequest.parse_obj(json_dict["flow"]) + elif file_extension == "zip": + from ..service.share_utils import _parse_flow_from_zip_file + + # Handle zip file + flow = await _parse_flow_from_zip_file(file, global_system_app) + else: + raise HTTPException( + status_code=400, detail=f"invalid file extension {file_extension}" + ) + if save_flow: + return Result.succ(service.create_and_save_dag(flow)) + else: + return Result.succ(flow) def init_endpoints(system_app: SystemApp) -> None: """Initialize the endpoints""" + from .variables_provider import ( + BuiltinAllSecretVariablesProvider, + BuiltinAllVariablesProvider, + BuiltinEmbeddingsVariablesProvider, + BuiltinFlowVariablesProvider, + BuiltinLLMVariablesProvider, + BuiltinNodeVariablesProvider, + ) + global global_system_app system_app.register(Service) + system_app.register(VariablesService) + system_app.register(BuiltinFlowVariablesProvider) + system_app.register(BuiltinNodeVariablesProvider) + system_app.register(BuiltinAllVariablesProvider) + system_app.register(BuiltinAllSecretVariablesProvider) + system_app.register(BuiltinLLMVariablesProvider) + system_app.register(BuiltinEmbeddingsVariablesProvider) global_system_app = system_app diff --git a/dbgpt/serve/flow/api/schemas.py b/dbgpt/serve/flow/api/schemas.py index 6fb8c1924..cf82de982 100644 --- a/dbgpt/serve/flow/api/schemas.py +++ b/dbgpt/serve/flow/api/schemas.py @@ -1,7 +1,9 @@ -from dbgpt._private.pydantic import ConfigDict +from typing import Any, Dict, List, Literal, Optional, Union -# Define your Pydantic schemas here -from dbgpt.core.awel.flow.flow_factory import FlowPanel +from dbgpt._private.pydantic import BaseModel, ConfigDict, Field +from dbgpt.core.awel import CommonLLMHttpRequestBody +from dbgpt.core.awel.flow.flow_factory import FlowPanel, VariablesRequest +from dbgpt.core.awel.util.parameter_util import RefreshOptionRequest from ..config import SERVE_APP_NAME_HUMP @@ -14,3 +16,71 @@ class ServerResponse(FlowPanel): # TODO define your own 
fields here model_config = ConfigDict(title=f"ServerResponse for {SERVE_APP_NAME_HUMP}") + + +class VariablesResponse(VariablesRequest): + """Variable response model.""" + + id: int = Field( + ..., + description="The id of the variable", + examples=[1], + ) + + +class RefreshNodeRequest(BaseModel): + """Flow response model""" + + model_config = ConfigDict(title=f"RefreshNodeRequest") + id: str = Field( + ..., + title="The id of the node", + description="The id of the node to refresh", + examples=["operator_llm_operator___$$___llm___$$___v1"], + ) + flow_type: Literal["operator", "resource"] = Field( + "operator", + title="The type of the node", + description="The type of the node to refresh", + examples=["operator", "resource"], + ) + type_name: str = Field( + ..., + title="The type of the node", + description="The type of the node to refresh", + examples=["LLMOperator"], + ) + type_cls: str = Field( + ..., + title="The class of the node", + description="The class of the node to refresh", + examples=["dbgpt.core.operator.llm.LLMOperator"], + ) + refresh: List[RefreshOptionRequest] = Field( + ..., + title="The refresh options", + description="The refresh options", + ) + + +class FlowDebugRequest(BaseModel): + """Flow response model""" + + model_config = ConfigDict(title=f"FlowDebugRequest") + flow: ServeRequest = Field( + ..., + title="The flow to debug", + description="The flow to debug", + ) + request: Union[CommonLLMHttpRequestBody, Dict[str, Any]] = Field( + ..., + title="The request to debug", + description="The request to debug", + ) + variables: Optional[Dict[str, Any]] = Field( + None, + title="The variables to debug", + description="The variables to debug", + ) + user_name: Optional[str] = Field(None, description="User name") + sys_code: Optional[str] = Field(None, description="System code") diff --git a/dbgpt/serve/flow/api/variables_provider.py b/dbgpt/serve/flow/api/variables_provider.py new file mode 100644 index 000000000..4728f80e6 --- /dev/null +++ b/dbgpt/serve/flow/api/variables_provider.py @@ -0,0 +1,260 @@ +from typing import List, Literal, Optional + +from dbgpt.core.interface.variables import ( + BUILTIN_VARIABLES_CORE_EMBEDDINGS, + BUILTIN_VARIABLES_CORE_FLOW_NODES, + BUILTIN_VARIABLES_CORE_FLOWS, + BUILTIN_VARIABLES_CORE_LLMS, + BUILTIN_VARIABLES_CORE_SECRETS, + BUILTIN_VARIABLES_CORE_VARIABLES, + BuiltinVariablesProvider, + StorageVariables, +) + +from ..service.service import Service +from .endpoints import get_service, get_variable_service +from .schemas import ServerResponse + + +class BuiltinFlowVariablesProvider(BuiltinVariablesProvider): + """Builtin flow variables provider. + + Provide all flows by variables "${dbgpt.core.flow.flows}" + """ + + name = BUILTIN_VARIABLES_CORE_FLOWS + + def get_variables( + self, + key: str, + scope: str = "global", + scope_key: Optional[str] = None, + sys_code: Optional[str] = None, + user_name: Optional[str] = None, + ) -> List[StorageVariables]: + service: Service = get_service() + page_result = service.get_list_by_page( + { + "user_name": user_name, + "sys_code": sys_code, + }, + 1, + 1000, + ) + flows: List[ServerResponse] = page_result.items + variables = [] + for flow in flows: + variables.append( + StorageVariables( + key=key, + name=flow.name, + label=flow.label, + value=flow.uid, + scope=scope, + scope_key=scope_key, + sys_code=sys_code, + user_name=user_name, + ) + ) + return variables + + +class BuiltinNodeVariablesProvider(BuiltinVariablesProvider): + """Builtin node variables provider. 
+ + Provide all nodes by variables "${dbgpt.core.flow.nodes}" + """ + + name = BUILTIN_VARIABLES_CORE_FLOW_NODES + + def get_variables( + self, + key: str, + scope: str = "global", + scope_key: Optional[str] = None, + sys_code: Optional[str] = None, + user_name: Optional[str] = None, + ) -> List[StorageVariables]: + """Get the builtin variables.""" + from dbgpt.core.awel.flow.base import _OPERATOR_REGISTRY + + metadata_list = _OPERATOR_REGISTRY.metadata_list() + variables = [] + for metadata in metadata_list: + variables.append( + StorageVariables( + key=key, + name=metadata["name"], + label=metadata["label"], + value=metadata["id"], + scope=scope, + scope_key=scope_key, + sys_code=sys_code, + user_name=user_name, + ) + ) + return variables + + +class BuiltinAllVariablesProvider(BuiltinVariablesProvider): + """Builtin all variables provider. + + Provide all variables by variables "${dbgpt.core.variables}" + """ + + name = BUILTIN_VARIABLES_CORE_VARIABLES + + def _get_variables_from_db( + self, + key: str, + scope: str, + scope_key: Optional[str], + sys_code: Optional[str], + user_name: Optional[str], + category: Literal["common", "secret"] = "common", + ) -> List[StorageVariables]: + storage_variables = get_variable_service().list_all_variables(category) + variables = [] + for var in storage_variables: + variables.append( + StorageVariables( + key=key, + name=var.name, + label=var.label, + value=var.value, + scope=scope, + scope_key=scope_key, + sys_code=sys_code, + user_name=user_name, + ) + ) + return variables + + def get_variables( + self, + key: str, + scope: str = "global", + scope_key: Optional[str] = None, + sys_code: Optional[str] = None, + user_name: Optional[str] = None, + ) -> List[StorageVariables]: + """Get the builtin variables. + + TODO: Return all builtin variables + """ + return self._get_variables_from_db(key, scope, scope_key, sys_code, user_name) + + +class BuiltinAllSecretVariablesProvider(BuiltinAllVariablesProvider): + """Builtin all secret variables provider. + + Provide all secret variables by variables "${dbgpt.core.secrets}" + """ + + name = BUILTIN_VARIABLES_CORE_SECRETS + + def get_variables( + self, + key: str, + scope: str = "global", + scope_key: Optional[str] = None, + sys_code: Optional[str] = None, + user_name: Optional[str] = None, + ) -> List[StorageVariables]: + """Get the builtin variables.""" + return self._get_variables_from_db( + key, scope, scope_key, sys_code, user_name, "secret" + ) + + +class BuiltinLLMVariablesProvider(BuiltinVariablesProvider): + """Builtin LLM variables provider. 
+ + Provide all LLM variables by variables "${dbgpt.core.llmv}" + """ + + name = BUILTIN_VARIABLES_CORE_LLMS + + def support_async(self) -> bool: + """Whether the dynamic options support async.""" + return True + + async def _get_models( + self, + key: str, + scope: str, + scope_key: Optional[str], + sys_code: Optional[str], + user_name: Optional[str], + expect_worker_type: str = "llm", + ) -> List[StorageVariables]: + from dbgpt.model.cluster.controller.controller import BaseModelController + + controller = BaseModelController.get_instance(self.system_app) + models = await controller.get_all_instances(healthy_only=True) + model_dict = {} + for model in models: + worker_name, worker_type = model.model_name.split("@") + if expect_worker_type == worker_type: + model_dict[worker_name] = model + variables = [] + for worker_name, model in model_dict.items(): + variables.append( + StorageVariables( + key=key, + name=worker_name, + label=worker_name, + value=worker_name, + scope=scope, + scope_key=scope_key, + sys_code=sys_code, + user_name=user_name, + ) + ) + return variables + + async def async_get_variables( + self, + key: str, + scope: str = "global", + scope_key: Optional[str] = None, + sys_code: Optional[str] = None, + user_name: Optional[str] = None, + ) -> List[StorageVariables]: + """Get the builtin variables.""" + return await self._get_models(key, scope, scope_key, sys_code, user_name) + + def get_variables( + self, + key: str, + scope: str = "global", + scope_key: Optional[str] = None, + sys_code: Optional[str] = None, + user_name: Optional[str] = None, + ) -> List[StorageVariables]: + """Get the builtin variables.""" + raise NotImplementedError( + "Not implemented get variables sync, please use async_get_variables" + ) + + +class BuiltinEmbeddingsVariablesProvider(BuiltinLLMVariablesProvider): + """Builtin embeddings variables provider. + + Provide all embeddings variables by variables "${dbgpt.core.embeddings}" + """ + + name = BUILTIN_VARIABLES_CORE_EMBEDDINGS + + async def async_get_variables( + self, + key: str, + scope: str = "global", + scope_key: Optional[str] = None, + sys_code: Optional[str] = None, + user_name: Optional[str] = None, + ) -> List[StorageVariables]: + """Get the builtin variables.""" + return await self._get_models( + key, scope, scope_key, sys_code, user_name, "text2vec" + ) diff --git a/dbgpt/serve/flow/config.py b/dbgpt/serve/flow/config.py index 97eea7478..0cc35667d 100644 --- a/dbgpt/serve/flow/config.py +++ b/dbgpt/serve/flow/config.py @@ -8,8 +8,10 @@ SERVE_APP_NAME = "dbgpt_serve_flow" SERVE_APP_NAME_HUMP = "dbgpt_serve_Flow" SERVE_CONFIG_KEY_PREFIX = "dbgpt.serve.flow." 
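 # A minimal usage sketch (not part of this change set): with the
 # `encrypt_key` field added to ServeConfig below, the serve app can build an
 # encrypted variables store, assuming `app` is the running SystemApp:
 #
 #   config = ServeConfig.from_app_config(app.config, SERVE_CONFIG_KEY_PREFIX)
 #   encryption = FernetEncryption(config.encrypt_key)
 #
 # FernetEncryption and StorageVariablesProvider come from
 # dbgpt.core.interface.variables (see the serve.py hunk later in this diff).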
SERVE_SERVICE_COMPONENT_NAME = f"{SERVE_APP_NAME}_service" +SERVE_VARIABLES_SERVICE_COMPONENT_NAME = f"{SERVE_APP_NAME}_variables_service" # Database table name SERVER_APP_TABLE_NAME = "dbgpt_serve_flow" +SERVER_APP_VARIABLES_TABLE_NAME = "dbgpt_serve_variables" @dataclass @@ -23,3 +25,6 @@ class ServeConfig(BaseServeConfig): load_dbgpts_interval: int = field( default=5, metadata={"help": "Interval to load dbgpts from installed packages"} ) + encrypt_key: Optional[str] = field( + default=None, metadata={"help": "The key to encrypt the data"} + ) diff --git a/dbgpt/serve/flow/models/models.py b/dbgpt/serve/flow/models/models.py index ea4c7f3ea..0dae2b6cf 100644 --- a/dbgpt/serve/flow/models/models.py +++ b/dbgpt/serve/flow/models/models.py @@ -10,11 +10,17 @@ from sqlalchemy import Column, DateTime, Integer, String, Text, UniqueConstraint from dbgpt._private.pydantic import model_to_dict from dbgpt.core.awel.flow.flow_factory import State +from dbgpt.core.interface.variables import StorageVariablesProvider from dbgpt.storage.metadata import BaseDao, Model from dbgpt.storage.metadata._base_dao import QUERY_SPEC -from ..api.schemas import ServeRequest, ServerResponse -from ..config import SERVER_APP_TABLE_NAME, ServeConfig +from ..api.schemas import ( + ServeRequest, + ServerResponse, + VariablesRequest, + VariablesResponse, +) +from ..config import SERVER_APP_TABLE_NAME, SERVER_APP_VARIABLES_TABLE_NAME, ServeConfig class ServeEntity(Model): @@ -43,6 +49,7 @@ class ServeEntity(Model): editable = Column( Integer, nullable=True, comment="Editable, 0: editable, 1: not editable" ) + variables = Column(Text, nullable=True, comment="Flow variables, JSON format") user_name = Column(String(128), index=True, nullable=True, comment="User name") sys_code = Column(String(128), index=True, nullable=True, comment="System code") gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time") @@ -74,6 +81,57 @@ class ServeEntity(Model): return editable is None or editable == 0 +class VariablesEntity(Model): + __tablename__ = SERVER_APP_VARIABLES_TABLE_NAME + + id = Column(Integer, primary_key=True, comment="Auto increment id") + key = Column(String(128), index=True, nullable=False, comment="Variable key") + name = Column(String(128), index=True, nullable=True, comment="Variable name") + label = Column(String(128), nullable=True, comment="Variable label") + value = Column(Text, nullable=True, comment="Variable value, JSON format") + value_type = Column( + String(32), + nullable=True, + comment="Variable value type(string, int, float, bool)", + ) + category = Column( + String(32), + default="common", + nullable=True, + comment="Variable category(common or secret)", + ) + encryption_method = Column( + String(32), + nullable=True, + comment="Variable encryption method(fernet, simple, rsa, aes)", + ) + salt = Column(String(128), nullable=True, comment="Variable salt") + scope = Column( + String(32), + default="global", + nullable=True, + comment="Variable scope(global,flow,app,agent,datasource,flow_priv,agent_priv, " + "etc)", + ) + scope_key = Column( + String(256), + nullable=True, + comment="Variable scope key, default is empty, for scope is 'flow_priv', " + "the scope_key is dag id of flow", + ) + enabled = Column( + Integer, + default=1, + nullable=True, + comment="Variable enabled, 0: disabled, 1: enabled", + ) + description = Column(Text, nullable=True, comment="Variable description") + user_name = Column(String(128), index=True, nullable=True, comment="User name") + sys_code = 
Column(String(128), index=True, nullable=True, comment="System code") + gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time") + gmt_modified = Column(DateTime, default=datetime.now, comment="Record update time") + + class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]): """The DAO class for Flow""" @@ -98,6 +156,11 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]): error_message = request_dict.get("error_message") if error_message: error_message = error_message[:500] + + variables_raw = request_dict.get("variables") + variables = ( + json.dumps(variables_raw, ensure_ascii=False) if variables_raw else None + ) new_dict = { "uid": request_dict.get("uid"), "dag_id": request_dict.get("dag_id"), @@ -113,6 +176,7 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]): "define_type": request_dict.get("define_type"), "editable": ServeEntity.parse_editable(request_dict.get("editable")), "description": request_dict.get("description"), + "variables": variables, "user_name": request_dict.get("user_name"), "sys_code": request_dict.get("sys_code"), } @@ -129,6 +193,8 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]): REQ: The request """ flow_data = json.loads(entity.flow_data) + variables_raw = json.loads(entity.variables) if entity.variables else None + variables = ServeRequest.parse_variables(variables_raw) return ServeRequest( uid=entity.uid, dag_id=entity.dag_id, @@ -144,6 +210,7 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]): define_type=entity.define_type, editable=ServeEntity.to_bool_editable(entity.editable), description=entity.description, + variables=variables, user_name=entity.user_name, sys_code=entity.sys_code, ) @@ -160,6 +227,8 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]): flow_data = json.loads(entity.flow_data) gmt_created_str = entity.gmt_created.strftime("%Y-%m-%d %H:%M:%S") gmt_modified_str = entity.gmt_modified.strftime("%Y-%m-%d %H:%M:%S") + variables_raw = json.loads(entity.variables) if entity.variables else None + variables = ServeRequest.parse_variables(variables_raw) return ServerResponse( uid=entity.uid, dag_id=entity.dag_id, @@ -175,6 +244,7 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]): version=entity.version, editable=ServeEntity.to_bool_editable(entity.editable), define_type=entity.define_type, + variables=variables, user_name=entity.user_name, sys_code=entity.sys_code, gmt_created=gmt_created_str, @@ -215,6 +285,14 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]): entry.editable = ServeEntity.parse_editable(update_request.editable) if update_request.define_type: entry.define_type = update_request.define_type + + if update_request.variables: + variables_raw = update_request.get_variables_dict() + entry.variables = ( + json.dumps(variables_raw, ensure_ascii=False) + if variables_raw + else None + ) if update_request.user_name: entry.user_name = update_request.user_name if update_request.sys_code: @@ -222,3 +300,111 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]): session.merge(entry) session.commit() return self.get_one(query_request) + + +class VariablesDao(BaseDao[VariablesEntity, VariablesRequest, VariablesResponse]): + """The DAO class for Variables""" + + def __init__(self, serve_config: ServeConfig): + super().__init__() + self._serve_config = serve_config + + def from_request( + self, request: Union[VariablesRequest, Dict[str, Any]] + ) -> 
VariablesEntity: + """Convert the request to an entity + + Args: + request (Union[VariablesRequest, Dict[str, Any]]): The request + + Returns: + T: The entity + """ + request_dict = ( + model_to_dict(request) if isinstance(request, VariablesRequest) else request + ) + value = StorageVariablesProvider.serialize_value(request_dict.get("value")) + enabled = 1 if request_dict.get("enabled", True) else 0 + new_dict = { + "key": request_dict.get("key"), + "name": request_dict.get("name"), + "label": request_dict.get("label"), + "value": value, + "value_type": request_dict.get("value_type"), + "category": request_dict.get("category"), + "encryption_method": request_dict.get("encryption_method"), + "salt": request_dict.get("salt"), + "scope": request_dict.get("scope"), + "scope_key": request_dict.get("scope_key"), + "enabled": enabled, + "user_name": request_dict.get("user_name"), + "sys_code": request_dict.get("sys_code"), + "description": request_dict.get("description"), + } + entity = VariablesEntity(**new_dict) + return entity + + def to_request(self, entity: VariablesEntity) -> VariablesRequest: + """Convert the entity to a request + + Args: + entity (T): The entity + + Returns: + REQ: The request + """ + value = StorageVariablesProvider.deserialize_value(entity.value) + if entity.category == "secret": + value = "******" + enabled = entity.enabled == 1 + return VariablesRequest( + key=entity.key, + name=entity.name, + label=entity.label, + value=value, + value_type=entity.value_type, + category=entity.category, + encryption_method=entity.encryption_method, + salt=entity.salt, + scope=entity.scope, + scope_key=entity.scope_key, + enabled=enabled, + user_name=entity.user_name, + sys_code=entity.sys_code, + description=entity.description, + ) + + def to_response(self, entity: VariablesEntity) -> VariablesResponse: + """Convert the entity to a response + + Args: + entity (T): The entity + + Returns: + RES: The response + """ + value = StorageVariablesProvider.deserialize_value(entity.value) + if entity.category == "secret": + value = "******" + gmt_created_str = entity.gmt_created.strftime("%Y-%m-%d %H:%M:%S") + gmt_modified_str = entity.gmt_modified.strftime("%Y-%m-%d %H:%M:%S") + enabled = entity.enabled == 1 + return VariablesResponse( + id=entity.id, + key=entity.key, + name=entity.name, + label=entity.label, + value=value, + value_type=entity.value_type, + category=entity.category, + encryption_method=entity.encryption_method, + salt=entity.salt, + scope=entity.scope, + scope_key=entity.scope_key, + enabled=enabled, + user_name=entity.user_name, + sys_code=entity.sys_code, + gmt_created=gmt_created_str, + gmt_modified=gmt_modified_str, + description=entity.description, + ) diff --git a/dbgpt/serve/flow/models/variables_adapter.py b/dbgpt/serve/flow/models/variables_adapter.py new file mode 100644 index 000000000..8c1bda50a --- /dev/null +++ b/dbgpt/serve/flow/models/variables_adapter.py @@ -0,0 +1,71 @@ +from typing import Type + +from sqlalchemy.orm import Session + +from dbgpt.core.interface.storage import StorageItemAdapter +from dbgpt.core.interface.variables import StorageVariables, VariablesIdentifier + +from .models import VariablesEntity + + +class VariablesAdapter(StorageItemAdapter[StorageVariables, VariablesEntity]): + """Variables adapter. + + Convert between storage format and database model. 
+    """
+
+    def to_storage_format(self, item: StorageVariables) -> VariablesEntity:
+        """Convert to storage format."""
+        return VariablesEntity(
+            key=item.key,
+            name=item.name,
+            label=item.label,
+            value=item.value,
+            value_type=item.value_type,
+            category=item.category,
+            encryption_method=item.encryption_method,
+            salt=item.salt,
+            scope=item.scope,
+            scope_key=item.scope_key,
+            sys_code=item.sys_code,
+            user_name=item.user_name,
+            description=item.description,
+        )
+
+    def from_storage_format(self, model: VariablesEntity) -> StorageVariables:
+        """Convert from storage format."""
+        return StorageVariables(
+            key=model.key,
+            name=model.name,
+            label=model.label,
+            value=model.value,
+            value_type=model.value_type,
+            category=model.category,
+            encryption_method=model.encryption_method,
+            salt=model.salt,
+            scope=model.scope,
+            scope_key=model.scope_key,
+            sys_code=model.sys_code,
+            user_name=model.user_name,
+            description=model.description,
+        )
+
+    def get_query_for_identifier(
+        self,
+        storage_format: Type[VariablesEntity],
+        resource_id: VariablesIdentifier,
+        **kwargs,
+    ):
+        """Get query for identifier."""
+        session: Session = kwargs.get("session")
+        if session is None:
+            raise Exception("session is None")
+        query_obj = session.query(VariablesEntity)
+        for key, value in resource_id.to_dict().items():
+            if value is None:
+                continue
+            query_obj = query_obj.filter(getattr(VariablesEntity, key) == value)
+
+        # enabled must be True
+        query_obj = query_obj.filter(VariablesEntity.enabled == 1)
+        return query_obj
diff --git a/dbgpt/serve/flow/serve.py b/dbgpt/serve/flow/serve.py
index 126841e57..a27e3d28f 100644
--- a/dbgpt/serve/flow/serve.py
+++ b/dbgpt/serve/flow/serve.py
@@ -4,6 +4,7 @@ from typing import List, Optional, Union
 from sqlalchemy import URL
 
 from dbgpt.component import SystemApp
+from dbgpt.core.interface.variables import VariablesProvider
 from dbgpt.serve.core import BaseServe
 from dbgpt.storage.metadata import DatabaseManager
 
@@ -40,6 +41,8 @@ class Serve(BaseServe):
             system_app, api_prefix, api_tags, db_url_or_db, try_create_tables
         )
         self._db_manager: Optional[DatabaseManager] = None
+        self._variables_provider: Optional[VariablesProvider] = None
+        self._serve_config: Optional[ServeConfig] = None
 
     def init_app(self, system_app: SystemApp):
         if self._app_has_initiated:
@@ -62,5 +65,37 @@ class Serve(BaseServe):
 
     def before_start(self):
         """Called before the start of the application."""
+        from dbgpt.core.interface.variables import (
+            FernetEncryption,
+            StorageVariablesProvider,
+        )
+        from dbgpt.storage.metadata.db_storage import SQLAlchemyStorage
+        from dbgpt.util.serialization.json_serialization import JsonSerializer
+
+        from .models.models import ServeEntity, VariablesEntity
+        from .models.variables_adapter import VariablesAdapter
+
+        self._db_manager = self.create_or_get_db_manager()
+        self._serve_config = ServeConfig.from_app_config(
+            self._system_app.config, SERVE_CONFIG_KEY_PREFIX
+        )
+
+        storage_adapter = VariablesAdapter()
+        serializer = JsonSerializer()
+        storage = SQLAlchemyStorage(
+            self._db_manager,
+            VariablesEntity,
+            storage_adapter,
+            serializer,
+        )
+        self._variables_provider = StorageVariablesProvider(
+            storage=storage,
+            encryption=FernetEncryption(self._serve_config.encrypt_key),
+            system_app=self._system_app,
+        )
+
+    @property
+    def variables_provider(self):
+        """Get the variables provider of the serve app with db storage."""
+        return self._variables_provider
diff --git 
a/dbgpt/serve/flow/service/service.py b/dbgpt/serve/flow/service/service.py index 044a22610..1b8a83422 100644 --- a/dbgpt/serve/flow/service/service.py +++ b/dbgpt/serve/flow/service/service.py @@ -9,7 +9,6 @@ from dbgpt._private.pydantic import model_to_json from dbgpt.agent import AgentDummyTrigger from dbgpt.component import SystemApp from dbgpt.core.awel import DAG, BaseOperator, CommonLLMHttpRequestBody -from dbgpt.core.awel.dag.dag_manager import DAGManager from dbgpt.core.awel.flow.flow_factory import ( FlowCategory, FlowFactory, @@ -34,7 +33,7 @@ from dbgpt.storage.metadata._base_dao import QUERY_SPEC from dbgpt.util.dbgpts.loader import DBGPTsLoader from dbgpt.util.pagination_utils import PaginationResult -from ..api.schemas import ServeRequest, ServerResponse +from ..api.schemas import FlowDebugRequest, ServeRequest, ServerResponse from ..config import SERVE_CONFIG_KEY_PREFIX, SERVE_SERVICE_COMPONENT_NAME, ServeConfig from ..models.models import ServeDao, ServeEntity @@ -147,7 +146,9 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]): raise ValueError( f"Create DAG {request.name} error, define_type: {request.define_type}, error: {str(e)}" ) from e - res = self.dao.create(request) + self.dao.create(request) + # Query from database + res = self.get({"uid": request.uid}) state = request.state try: @@ -574,3 +575,61 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]): return FlowCategory.CHAT_FLOW except Exception: return FlowCategory.COMMON + + async def debug_flow( + self, request: FlowDebugRequest, default_incremental: Optional[bool] = None + ) -> AsyncIterator[ModelOutput]: + """Debug the flow. + + Args: + request (FlowDebugRequest): The request + default_incremental (Optional[bool]): The default incremental configuration + + Returns: + AsyncIterator[ModelOutput]: The output + """ + from dbgpt.core.awel.dag.dag_manager import DAGMetadata, _parse_metadata + + dag = self._flow_factory.build(request.flow) + leaf_nodes = dag.leaf_nodes + if len(leaf_nodes) != 1: + raise ValueError("Chat Flow just support one leaf node in dag") + task = cast(BaseOperator, leaf_nodes[0]) + dag_metadata = _parse_metadata(dag) + # TODO: Run task with variables + variables = request.variables + dag_request = request.request + + if isinstance(request.request, CommonLLMHttpRequestBody): + incremental = request.request.incremental + elif isinstance(request.request, dict): + incremental = request.request.get("incremental", False) + else: + raise ValueError("Invalid request type") + + if default_incremental is not None: + incremental = default_incremental + + try: + async for output in safe_chat_stream_with_dag_task( + task, dag_request, incremental + ): + yield output + except HTTPException as e: + yield ModelOutput(error_code=1, text=e.detail, incremental=incremental) + except Exception as e: + yield ModelOutput(error_code=1, text=str(e), incremental=incremental) + + async def _wrapper_chat_stream_flow_str( + self, stream_iter: AsyncIterator[ModelOutput] + ) -> AsyncIterator[str]: + + async for output in stream_iter: + text = output.text + if text: + text = text.replace("\n", "\\n") + if output.error_code != 0: + yield f"data:[SERVER_ERROR]{text}\n\n" + break + else: + yield f"data:{text}\n\n" diff --git a/dbgpt/serve/flow/service/share_utils.py b/dbgpt/serve/flow/service/share_utils.py new file mode 100644 index 000000000..99ba222a9 --- /dev/null +++ b/dbgpt/serve/flow/service/share_utils.py @@ -0,0 +1,121 @@ +import io +import json +import os +import tempfile 
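+# Sketch of the archive produced by `_generate_dbgpts_zip` below, derived from
+# its writestr calls (<package_name> and <flow_name> depend on the flow):
+#
+#   <package_name>/
+#       MANIFEST.in
+#       README.md
+#       dbgpts.toml
+#       pyproject.toml
+#       <flow_name>/
+#           __init__.py
+#           definition/flow_definition.json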
+import zipfile
+
+import aiofiles
+import tomlkit
+from fastapi import UploadFile
+
+from dbgpt.component import SystemApp
+from dbgpt.serve.core import blocking_func_to_async
+
+from ..api.schemas import ServeRequest
+
+
+def _generate_dbgpts_zip(package_name: str, flow: ServeRequest) -> io.BytesIO:
+    """Pack the given flow into an in-memory dbgpts package zip."""
+    zip_buffer = io.BytesIO()
+    flow_name = flow.name
+    flow_label = flow.label
+    flow_description = flow.description
+    dag_json = json.dumps(flow.flow_data.dict(), indent=4, ensure_ascii=False)
+    with zipfile.ZipFile(zip_buffer, "a", zipfile.ZIP_DEFLATED, False) as zip_file:
+        manifest = f"include dbgpts.toml\ninclude {flow_name}/definition/*.json"
+        readme = f"# {flow_label}\n\n{flow_description}"
+        zip_file.writestr(f"{package_name}/MANIFEST.in", manifest)
+        zip_file.writestr(f"{package_name}/README.md", readme)
+        zip_file.writestr(
+            f"{package_name}/{flow_name}/__init__.py",
+            "",
+        )
+        zip_file.writestr(
+            f"{package_name}/{flow_name}/definition/flow_definition.json",
+            dag_json,
+        )
+        dbgpts_toml = tomlkit.document()
+        # Add flow information
+        dbgpts_flow_toml = tomlkit.document()
+        dbgpts_flow_toml.add("label", flow_label)
+        name_with_comment = tomlkit.string(flow_name)
+        name_with_comment.comment("A unique name for all dbgpts")
+        dbgpts_flow_toml.add("name", name_with_comment)
+
+        dbgpts_flow_toml.add("version", "0.1.0")
+        dbgpts_flow_toml.add(
+            "description",
+            flow_description,
+        )
+        dbgpts_flow_toml.add("authors", [])
+
+        definition_type_with_comment = tomlkit.string("json")
+        definition_type_with_comment.comment("How to define the flow, python or json")
+        dbgpts_flow_toml.add("definition_type", definition_type_with_comment)
+
+        dbgpts_toml.add("flow", dbgpts_flow_toml)
+
+        # Add python and json config
+        python_config = tomlkit.table()
+        dbgpts_toml.add("python_config", python_config)
+
+        json_config = tomlkit.table()
+        json_config.add("file_path", "definition/flow_definition.json")
+        json_config.comment("Json config")
+
+        dbgpts_toml.add("json_config", json_config)
+
+        # Transform to string
+        toml_string = tomlkit.dumps(dbgpts_toml)
+        zip_file.writestr(f"{package_name}/dbgpts.toml", toml_string)
+
+        pyproject_toml = tomlkit.document()
+
+        # Add [tool.poetry] section
+        tool_poetry_toml = tomlkit.table()
+        tool_poetry_toml.add("name", package_name)
+        tool_poetry_toml.add("version", "0.1.0")
+        tool_poetry_toml.add("description", "A dbgpts package")
+        tool_poetry_toml.add("authors", [])
+        tool_poetry_toml.add("readme", "README.md")
+        pyproject_toml["tool"] = tomlkit.table()
+        pyproject_toml["tool"]["poetry"] = tool_poetry_toml
+
+        # Add [tool.poetry.dependencies] section
+        dependencies = tomlkit.table()
+        dependencies.add("python", "^3.10")
+        pyproject_toml["tool"]["poetry"]["dependencies"] = dependencies
+
+        # Add [build-system] section
+        build_system = tomlkit.table()
+        build_system.add("requires", ["poetry-core"])
+        build_system.add("build-backend", "poetry.core.masonry.api")
+        pyproject_toml["build-system"] = build_system
+
+        # Transform to string
+        pyproject_toml_string = tomlkit.dumps(pyproject_toml)
+        zip_file.writestr(f"{package_name}/pyproject.toml", pyproject_toml_string)
+    zip_buffer.seek(0)
+    return zip_buffer
+
+
+async def _parse_flow_from_zip_file(
+    file: UploadFile, sys_app: SystemApp
+) -> ServeRequest:
+    from dbgpt.util.dbgpts.loader import _load_flow_package_from_zip_path
+
+    filename = file.filename
+    if not filename.endswith(".zip"):
+        raise ValueError("Uploaded file must be a ZIP file")
+
+    with 
tempfile.TemporaryDirectory() as temp_dir: + zip_path = os.path.join(temp_dir, filename) + + # Save uploaded file to temporary directory + async with aiofiles.open(zip_path, "wb") as out_file: + while content := await file.read(1024 * 64): # Read in chunks of 64KB + await out_file.write(content) + flow = await blocking_func_to_async( + sys_app, _load_flow_package_from_zip_path, zip_path + ) + return flow diff --git a/dbgpt/serve/flow/service/variables_service.py b/dbgpt/serve/flow/service/variables_service.py new file mode 100644 index 000000000..09e2a16b0 --- /dev/null +++ b/dbgpt/serve/flow/service/variables_service.py @@ -0,0 +1,152 @@ +from typing import List, Optional + +from dbgpt import SystemApp +from dbgpt.core.interface.variables import StorageVariables, VariablesProvider +from dbgpt.serve.core import BaseService + +from ..api.schemas import VariablesRequest, VariablesResponse +from ..config import ( + SERVE_CONFIG_KEY_PREFIX, + SERVE_VARIABLES_SERVICE_COMPONENT_NAME, + ServeConfig, +) +from ..models.models import VariablesDao, VariablesEntity + + +class VariablesService( + BaseService[VariablesEntity, VariablesRequest, VariablesResponse] +): + """Variables service""" + + name = SERVE_VARIABLES_SERVICE_COMPONENT_NAME + + def __init__(self, system_app: SystemApp, dao: Optional[VariablesDao] = None): + self._system_app = None + self._serve_config: ServeConfig = None + self._dao: VariablesDao = dao + + super().__init__(system_app) + + def init_app(self, system_app: SystemApp) -> None: + """Initialize the service + + Args: + system_app (SystemApp): The system app + """ + super().init_app(system_app) + + self._serve_config = ServeConfig.from_app_config( + system_app.config, SERVE_CONFIG_KEY_PREFIX + ) + self._dao = self._dao or VariablesDao(self._serve_config) + self._system_app = system_app + + @property + def dao(self) -> VariablesDao: + """Returns the internal DAO.""" + return self._dao + + @property + def variables_provider(self) -> VariablesProvider: + """Returns the internal VariablesProvider. + + Returns: + VariablesProvider: The internal VariablesProvider + """ + variables_provider = VariablesProvider.get_instance( + self._system_app, default_component=None + ) + if variables_provider: + return variables_provider + else: + from ..serve import Serve + + variables_provider = Serve.get_instance(self._system_app).variables_provider + self._system_app.register_instance(variables_provider) + return variables_provider + + @property + def config(self) -> ServeConfig: + """Returns the internal ServeConfig.""" + return self._serve_config + + def create(self, request: VariablesRequest) -> VariablesResponse: + """Create a new entity + + Args: + request (VariablesRequest): The request + + Returns: + VariablesResponse: The response + """ + variables = StorageVariables( + key=request.key, + name=request.name, + label=request.label, + value=request.value, + value_type=request.value_type, + category=request.category, + scope=request.scope, + scope_key=request.scope_key, + user_name=request.user_name, + sys_code=request.sys_code, + enabled=1 if request.enabled else 0, + description=request.description, + ) + self.variables_provider.save(variables) + query = { + "key": request.key, + "name": request.name, + "scope": request.scope, + "scope_key": request.scope_key, + "sys_code": request.sys_code, + "user_name": request.user_name, + "enabled": request.enabled, + } + return self.dao.get_one(query) + + def update(self, _: int, request: VariablesRequest) -> VariablesResponse: + """Update variables. 
+ + Args: + request (VariablesRequest): The request + + Returns: + VariablesResponse: The response + """ + variables = StorageVariables( + key=request.key, + name=request.name, + label=request.label, + value=request.value, + value_type=request.value_type, + category=request.category, + scope=request.scope, + scope_key=request.scope_key, + user_name=request.user_name, + sys_code=request.sys_code, + enabled=1 if request.enabled else 0, + description=request.description, + ) + exist_value = self.variables_provider.get( + variables.identifier.str_identifier, None + ) + if exist_value is None: + raise ValueError( + f"Variable {variables.identifier.str_identifier} not found" + ) + self.variables_provider.save(variables) + query = { + "key": request.key, + "name": request.name, + "scope": request.scope, + "scope_key": request.scope_key, + "sys_code": request.sys_code, + "user_name": request.user_name, + "enabled": request.enabled, + } + return self.dao.get_one(query) + + def list_all_variables(self, category: str = "common") -> List[VariablesResponse]: + """List all variables.""" + return self.dao.get_list({"enabled": True, "category": category}) diff --git a/dbgpt/storage/chat_history/chat_history_db.py b/dbgpt/storage/chat_history/chat_history_db.py index 55f15cfed..bef4721ba 100644 --- a/dbgpt/storage/chat_history/chat_history_db.py +++ b/dbgpt/storage/chat_history/chat_history_db.py @@ -1,4 +1,5 @@ """Chat history database model.""" + import logging from datetime import datetime from typing import Optional @@ -56,7 +57,7 @@ class ChatHistoryEntity(Model): Index("idx_q_user", "user_name") Index("idx_q_mode", "chat_mode") Index("idx_q_conv", "summary") - Index("idx_app_code", "app_code") + Index("idx_chat_his_app_code", "app_code") class ChatHistoryMessageEntity(Model): diff --git a/dbgpt/util/dbgpts/loader.py b/dbgpt/util/dbgpts/loader.py index 4695db7d7..3693e5ed1 100644 --- a/dbgpt/util/dbgpts/loader.py +++ b/dbgpt/util/dbgpts/loader.py @@ -328,14 +328,19 @@ def _load_package_from_path(path: str): return parsed_packages -def _load_flow_package_from_path(name: str, path: str = INSTALL_DIR) -> FlowPackage: +def _load_flow_package_from_path( + name: str, path: str = INSTALL_DIR, filter_by_name: bool = True +) -> FlowPackage: raw_packages = _load_installed_package(path) new_name = name.replace("_", "-") - packages = [p for p in raw_packages if p.package == name or p.name == name] - if not packages: - packages = [ - p for p in raw_packages if p.package == new_name or p.name == new_name - ] + if filter_by_name: + packages = [p for p in raw_packages if p.package == name or p.name == name] + if not packages: + packages = [ + p for p in raw_packages if p.package == new_name or p.name == new_name + ] + else: + packages = raw_packages if not packages: raise ValueError(f"Can't find the package {name} or {new_name}") flow_package = _parse_package_metadata(packages[0]) @@ -344,6 +349,35 @@ def _load_flow_package_from_path(name: str, path: str = INSTALL_DIR) -> FlowPack return cast(FlowPackage, flow_package) +def _load_flow_package_from_zip_path(zip_path: str) -> FlowPanel: + import tempfile + import zipfile + + with tempfile.TemporaryDirectory() as temp_dir: + with zipfile.ZipFile(zip_path, "r") as zip_ref: + zip_ref.extractall(temp_dir) + package_names = os.listdir(temp_dir) + if not package_names: + raise ValueError("No package found in the zip file") + if len(package_names) > 1: + raise ValueError("Only support one package in the zip file") + package_name = package_names[0] + with open( + 
Path(temp_dir) / package_name / INSTALL_METADATA_FILE, mode="w+" + ) as f: + # Write the metadata + import tomlkit + + install_metadata = { + "name": package_name, + "repo": "local/dbgpts", + } + tomlkit.dump(install_metadata, f) + + package = _load_flow_package_from_path("", path=temp_dir, filter_by_name=False) + return _flow_package_to_flow_panel(package) + + def _flow_package_to_flow_panel(package: FlowPackage) -> FlowPanel: dict_value = { "name": package.name, @@ -353,6 +387,7 @@ def _flow_package_to_flow_panel(package: FlowPackage) -> FlowPanel: "description": package.description, "source": package.repo, "define_type": "json", + "authors": package.authors, } if isinstance(package, FlowJsonPackage): dict_value["flow_data"] = package.read_definition_json() diff --git a/examples/awel/awel_flow_ui_components.py b/examples/awel/awel_flow_ui_components.py new file mode 100644 index 000000000..8a1eb04b0 --- /dev/null +++ b/examples/awel/awel_flow_ui_components.py @@ -0,0 +1,1246 @@ +"""Some UI components for the AWEL flow.""" + +import json +import logging +from typing import Any, Dict, List, Optional + +from dbgpt.core.awel import MapOperator +from dbgpt.core.awel.flow import ( + FunctionDynamicOptions, + IOField, + OperatorCategory, + OptionValue, + Parameter, + VariablesDynamicOptions, + ViewMetadata, + ui, +) +from dbgpt.core.interface.file import FileStorageClient +from dbgpt.core.interface.variables import ( + BUILTIN_VARIABLES_CORE_EMBEDDINGS, + BUILTIN_VARIABLES_CORE_FLOW_NODES, + BUILTIN_VARIABLES_CORE_FLOWS, + BUILTIN_VARIABLES_CORE_LLMS, + BUILTIN_VARIABLES_CORE_SECRETS, + BUILTIN_VARIABLES_CORE_VARIABLES, +) + +logger = logging.getLogger(__name__) + + +class ExampleFlowSelectOperator(MapOperator[str, str]): + """An example flow operator that includes a select as parameter.""" + + metadata = ViewMetadata( + label="Example Flow Select", + name="example_flow_select", + category=OperatorCategory.EXAMPLE, + description="An example flow operator that includes a select as parameter.", + parameters=[ + Parameter.build_from( + "Fruits Selector", + "fruits", + type=str, + optional=True, + default=None, + placeholder="Select the fruits", + description="The fruits you like.", + options=[ + OptionValue(label="Apple", name="apple", value="apple"), + OptionValue(label="Banana", name="banana", value="banana"), + OptionValue(label="Orange", name="orange", value="orange"), + OptionValue(label="Pear", name="pear", value="pear"), + ], + ui=ui.UISelect(attr=ui.UISelect.UIAttribute(show_search=True)), + ) + ], + inputs=[ + IOField.build_from( + "User Name", + "user_name", + str, + description="The name of the user.", + ) + ], + outputs=[ + IOField.build_from( + "Fruits", + "fruits", + str, + description="User's favorite fruits.", + ) + ], + ) + + def __init__(self, fruits: Optional[str] = None, **kwargs): + super().__init__(**kwargs) + self.fruits = fruits + + async def map(self, user_name: str) -> str: + """Map the user name to the fruits.""" + return "Your name is %s, and you like %s." 
% (user_name, self.fruits)
+
+
+class ExampleFlowCascaderOperator(MapOperator[str, str]):
+    """An example flow operator that includes a cascader as parameter."""
+
+    metadata = ViewMetadata(
+        label="Example Flow Cascader",
+        name="example_flow_cascader",
+        category=OperatorCategory.EXAMPLE,
+        description="An example flow operator that includes a cascader as parameter.",
+        parameters=[
+            Parameter.build_from(
+                "Address Selector",
+                "address",
+                type=str,
+                is_list=True,
+                optional=True,
+                default=None,
+                placeholder="Select the address",
+                description="The address of the location.",
+                options=[
+                    OptionValue(
+                        label="Zhejiang",
+                        name="zhejiang",
+                        value="zhejiang",
+                        children=[
+                            OptionValue(
+                                label="Hangzhou",
+                                name="hangzhou",
+                                value="hangzhou",
+                                children=[
+                                    OptionValue(
+                                        label="Xihu",
+                                        name="xihu",
+                                        value="xihu",
+                                    ),
+                                    OptionValue(
+                                        label="Feilaifeng",
+                                        name="feilaifeng",
+                                        value="feilaifeng",
+                                    ),
+                                ],
+                            ),
+                        ],
+                    ),
+                    OptionValue(
+                        label="Jiangsu",
+                        name="jiangsu",
+                        value="jiangsu",
+                        children=[
+                            OptionValue(
+                                label="Nanjing",
+                                name="nanjing",
+                                value="nanjing",
+                                children=[
+                                    OptionValue(
+                                        label="Zhonghua Gate",
+                                        name="zhonghuamen",
+                                        value="zhonghuamen",
+                                    ),
+                                    OptionValue(
+                                        label="Zhongshanling",
+                                        name="zhongshanling",
+                                        value="zhongshanling",
+                                    ),
+                                ],
+                            ),
+                        ],
+                    ),
+                ],
+                ui=ui.UICascader(attr=ui.UICascader.UIAttribute(show_search=True)),
+            )
+        ],
+        inputs=[
+            IOField.build_from(
+                "User Name",
+                "user_name",
+                str,
+                description="The name of the user.",
+            )
+        ],
+        outputs=[
+            IOField.build_from(
+                "Address",
+                "address",
+                str,
+                description="User's address.",
+            )
+        ],
+    )
+
+    def __init__(self, address: Optional[List[str]] = None, **kwargs):
+        super().__init__(**kwargs)
+        self.address = address or []
+
+    async def map(self, user_name: str) -> str:
+        """Map the user name to the address."""
+        full_address_str = " ".join(self.address)
+        return "Your name is %s, and your address is %s." % (
+            user_name,
+            full_address_str,
+        )
+
+
+class ExampleFlowCheckboxOperator(MapOperator[str, str]):
+    """An example flow operator that includes a checkbox as parameter."""
+
+    metadata = ViewMetadata(
+        label="Example Flow Checkbox",
+        name="example_flow_checkbox",
+        category=OperatorCategory.EXAMPLE,
+        description="An example flow operator that includes a checkbox as parameter.",
+        parameters=[
+            Parameter.build_from(
+                "Fruits Selector",
+                "fruits",
+                type=str,
+                is_list=True,
+                optional=True,
+                default=None,
+                placeholder="Select the fruits",
+                description="The fruits you like.",
+                options=[
+                    OptionValue(label="Apple", name="apple", value="apple"),
+                    OptionValue(label="Banana", name="banana", value="banana"),
+                    OptionValue(label="Orange", name="orange", value="orange"),
+                    OptionValue(label="Pear", name="pear", value="pear"),
+                ],
+                ui=ui.UICheckbox(),
+            )
+        ],
+        inputs=[
+            IOField.build_from(
+                "User Name",
+                "user_name",
+                str,
+                description="The name of the user.",
+            )
+        ],
+        outputs=[
+            IOField.build_from(
+                "Fruits",
+                "fruits",
+                str,
+                description="User's favorite fruits.",
+            )
+        ],
+    )
+
+    def __init__(self, fruits: Optional[List[str]] = None, **kwargs):
+        super().__init__(**kwargs)
+        self.fruits = fruits or []
+
+    async def map(self, user_name: str) -> str:
+        """Map the user name to the fruits."""
+        return "Your name is %s, and you like %s." 
% (user_name, ", ".join(self.fruits)) + + +class ExampleFlowRadioOperator(MapOperator[str, str]): + """An example flow operator that includes a radio as parameter.""" + + metadata = ViewMetadata( + label="Example Flow Radio", + name="example_flow_radio", + category=OperatorCategory.EXAMPLE, + description="An example flow operator that includes a radio as parameter.", + parameters=[ + Parameter.build_from( + "Fruits Selector", + "fruits", + type=str, + optional=True, + default=None, + placeholder="Select the fruits", + description="The fruits you like.", + options=[ + OptionValue(label="Apple", name="apple", value="apple"), + OptionValue(label="Banana", name="banana", value="banana"), + OptionValue(label="Orange", name="orange", value="orange"), + OptionValue(label="Pear", name="pear", value="pear"), + ], + ui=ui.UIRadio(), + ) + ], + inputs=[ + IOField.build_from( + "User Name", + "user_name", + str, + description="The name of the user.", + ) + ], + outputs=[ + IOField.build_from( + "Fruits", + "fruits", + str, + description="User's favorite fruits.", + ) + ], + ) + + def __init__(self, fruits: Optional[str] = None, **kwargs): + super().__init__(**kwargs) + self.fruits = fruits + + async def map(self, user_name: str) -> str: + """Map the user name to the fruits.""" + return "Your name is %s, and you like %s." % (user_name, self.fruits) + + +class ExampleFlowDatePickerOperator(MapOperator[str, str]): + """An example flow operator that includes a date picker as parameter.""" + + metadata = ViewMetadata( + label="Example Flow Date Picker", + name="example_flow_date_picker", + category=OperatorCategory.EXAMPLE, + description="An example flow operator that includes a date picker as parameter.", + parameters=[ + Parameter.build_from( + "Date Selector", + "date", + type=str, + placeholder="Select the date", + description="The date you choose.", + ui=ui.UIDatePicker( + attr=ui.UIDatePicker.UIAttribute(placement="bottomLeft") + ), + ) + ], + inputs=[ + IOField.build_from( + "User Name", + "user_name", + str, + description="The name of the user.", + ) + ], + outputs=[ + IOField.build_from( + "Date", + "date", + str, + description="User's selected date.", + ) + ], + ) + + def __init__(self, date: str, **kwargs): + super().__init__(**kwargs) + self.date = date + + async def map(self, user_name: str) -> str: + """Map the user name to the date.""" + return "Your name is %s, and you choose the date %s." % (user_name, self.date) + + +class ExampleFlowInputOperator(MapOperator[str, str]): + """An example flow operator that includes an input as parameter.""" + + metadata = ViewMetadata( + label="Example Flow Input", + name="example_flow_input", + category=OperatorCategory.EXAMPLE, + description="An example flow operator that includes a input as parameter.", + parameters=[ + Parameter.build_from( + "Your hobby", + "hobby", + type=str, + placeholder="Please input your hobby", + description="The hobby you like.", + ui=ui.UIInput( + attr=ui.UIInput.UIAttribute( + prefix="icon:UserOutlined", show_count=True, maxlength=200 + ) + ), + ) + ], + inputs=[ + IOField.build_from( + "User Name", + "user_name", + str, + description="The name of the user.", + ) + ], + outputs=[ + IOField.build_from( + "User Hobby", + "hobby", + str, + description="User's hobby.", + ) + ], + ) + + def __init__(self, hobby: str, **kwargs): + super().__init__(**kwargs) + self.hobby = hobby + + async def map(self, user_name: str) -> str: + """Map the user name to the input.""" + return "Your name is %s, and your hobby is %s." 
% (user_name, self.hobby) + + +class ExampleFlowTextAreaOperator(MapOperator[str, str]): + """An example flow operator that includes a text area as parameter.""" + + metadata = ViewMetadata( + label="Example Flow Text Area", + name="example_flow_text_area", + category=OperatorCategory.EXAMPLE, + description="An example flow operator that includes a text area as parameter.", + parameters=[ + Parameter.build_from( + "Your comment", + "comment", + type=str, + placeholder="Please input your comment", + description="The comment you want to say.", + ui=ui.UITextArea( + attr=ui.UITextArea.UIAttribute( + show_count=True, + maxlength=1000, + auto_size=ui.UITextArea.UIAttribute.AutoSize( + min_rows=2, max_rows=6 + ), + ), + ), + ) + ], + inputs=[ + IOField.build_from( + "User Name", + "user_name", + str, + description="The name of the user.", + ) + ], + outputs=[ + IOField.build_from( + "User Comment", + "comment", + str, + description="User's comment.", + ) + ], + ) + + def __init__(self, comment: str, **kwargs): + super().__init__(**kwargs) + self.comment = comment + + async def map(self, user_name: str) -> str: + """Map the user name to the text area.""" + return "Your name is %s, and your comment is %s." % (user_name, self.comment) + + +class ExampleFlowSliderOperator(MapOperator[float, float]): + + metadata = ViewMetadata( + label="Example Flow Slider", + name="example_flow_slider", + category=OperatorCategory.EXAMPLE, + description="An example flow operator that includes a slider as parameter.", + parameters=[ + Parameter.build_from( + "Default Temperature", + "default_temperature", + type=float, + optional=True, + default=0.7, + placeholder="Set the default temperature, e.g., 0.7", + description="The default temperature to pass to the LLM.", + ui=ui.UISlider( + show_input=True, + attr=ui.UISlider.UIAttribute(min=0.0, max=2.0, step=0.1), + ), + ) + ], + inputs=[ + IOField.build_from( + "Temperature", + "temperature", + float, + description="The temperature.", + ) + ], + outputs=[ + IOField.build_from( + "Temperature", + "temperature", + float, + description="The temperature to pass to the LLM.", + ) + ], + ) + + def __init__(self, default_temperature: float = 0.7, **kwargs): + super().__init__(**kwargs) + self.default_temperature = default_temperature + + async def map(self, temperature: float) -> float: + """Map the temperature to the result.""" + if temperature < 0.0 or temperature > 2.0: + logger.warning("Temperature out of range: %s", temperature) + return self.default_temperature + else: + return temperature + + +class ExampleFlowSliderListOperator(MapOperator[float, float]): + """An example flow operator that includes a slider list as parameter.""" + + metadata = ViewMetadata( + label="Example Flow Slider List", + name="example_flow_slider_list", + category=OperatorCategory.EXAMPLE, + description="An example flow operator that includes a slider list as parameter.", + parameters=[ + Parameter.build_from( + "Temperature Selector", + "temperature_range", + type=float, + is_list=True, + optional=True, + default=None, + placeholder="Set the temperature, e.g., [0.1, 0.9]", + description="The temperature range to pass to the LLM.", + ui=ui.UISlider( + show_input=True, + attr=ui.UISlider.UIAttribute(min=0.0, max=2.0, step=0.1), + ), + ) + ], + inputs=[ + IOField.build_from( + "Temperature", + "temperature", + float, + description="The temperature.", + ) + ], + outputs=[ + IOField.build_from( + "Temperature", + "temperature", + float, + description="The temperature to pass to the LLM.", + ) + ], 
+ ) + + def __init__(self, temperature_range: Optional[List[float]] = None, **kwargs): + super().__init__(**kwargs) + temperature_range = temperature_range or [0.1, 0.9] + if temperature_range and len(temperature_range) != 2: + raise ValueError("The length of temperature range must be 2.") + self.temperature_range = temperature_range + + async def map(self, temperature: float) -> float: + """Map the temperature to the result.""" + min_temperature, max_temperature = self.temperature_range + if temperature < min_temperature or temperature > max_temperature: + logger.warning( + "Temperature out of range: %s, min: %s, max: %s", + temperature, + min_temperature, + max_temperature, + ) + return min_temperature + return temperature + + +class ExampleFlowTimePickerOperator(MapOperator[str, str]): + """An example flow operator that includes a time picker as parameter.""" + + metadata = ViewMetadata( + label="Example Flow Time Picker", + name="example_flow_time_picker", + category=OperatorCategory.EXAMPLE, + description="An example flow operator that includes a time picker as parameter.", + parameters=[ + Parameter.build_from( + "Time Selector", + "time", + type=str, + placeholder="Select the time", + description="The time you choose.", + ui=ui.UITimePicker( + attr=ui.UITimePicker.UIAttribute( + format="HH:mm:ss", hour_step=2, minute_step=10, second_step=10 + ), + ), + ) + ], + inputs=[ + IOField.build_from( + "User Name", + "user_name", + str, + description="The name of the user.", + ) + ], + outputs=[ + IOField.build_from( + "Time", + "time", + str, + description="User's selected time.", + ) + ], + ) + + def __init__(self, time: str, **kwargs): + super().__init__(**kwargs) + self.time = time + + async def map(self, user_name: str) -> str: + """Map the user name to the time.""" + return "Your name is %s, and you choose the time %s." 
% (user_name, self.time)
+
+
+class ExampleFlowTreeSelectOperator(MapOperator[str, str]):
+    """An example flow operator that includes a tree select as parameter."""
+
+    metadata = ViewMetadata(
+        label="Example Flow Tree Select",
+        name="example_flow_tree_select",
+        category=OperatorCategory.EXAMPLE,
+        description="An example flow operator that includes a tree select as parameter.",
+        parameters=[
+            Parameter.build_from(
+                "Address Selector",
+                "address",
+                type=str,
+                is_list=True,
+                optional=True,
+                default=None,
+                placeholder="Select the address",
+                description="The address of the location.",
+                options=[
+                    OptionValue(
+                        label="Zhejiang",
+                        name="zhejiang",
+                        value="zhejiang",
+                        children=[
+                            OptionValue(
+                                label="Hangzhou",
+                                name="hangzhou",
+                                value="hangzhou",
+                                children=[
+                                    OptionValue(
+                                        label="Xihu",
+                                        name="xihu",
+                                        value="xihu",
+                                    ),
+                                    OptionValue(
+                                        label="Feilaifeng",
+                                        name="feilaifeng",
+                                        value="feilaifeng",
+                                    ),
+                                ],
+                            ),
+                        ],
+                    ),
+                    OptionValue(
+                        label="Jiangsu",
+                        name="jiangsu",
+                        value="jiangsu",
+                        children=[
+                            OptionValue(
+                                label="Nanjing",
+                                name="nanjing",
+                                value="nanjing",
+                                children=[
+                                    OptionValue(
+                                        label="Zhonghua Gate",
+                                        name="zhonghuamen",
+                                        value="zhonghuamen",
+                                    ),
+                                    OptionValue(
+                                        label="Zhongshanling",
+                                        name="zhongshanling",
+                                        value="zhongshanling",
+                                    ),
+                                ],
+                            ),
+                        ],
+                    ),
+                ],
+                ui=ui.UITreeSelect(attr=ui.UITreeSelect.UIAttribute(show_search=True)),
+            )
+        ],
+        inputs=[
+            IOField.build_from(
+                "User Name",
+                "user_name",
+                str,
+                description="The name of the user.",
+            )
+        ],
+        outputs=[
+            IOField.build_from(
+                "Address",
+                "address",
+                str,
+                description="User's address.",
+            )
+        ],
+    )
+
+    def __init__(self, address: Optional[List[str]] = None, **kwargs):
+        super().__init__(**kwargs)
+        self.address = address or []
+
+    async def map(self, user_name: str) -> str:
+        """Map the user name to the address."""
+        full_address_str = " ".join(self.address)
+        return "Your name is %s, and your address is %s." 
+
+
+def get_recent_3_times(time_interval: int = 1) -> List[OptionValue]:
+    """Return the three most recent times, spaced ``time_interval`` hours apart."""
+    from datetime import datetime, timedelta
+
+    now = datetime.now()
+    recent_times = [now - timedelta(hours=time_interval * i) for i in range(3)]
+    formatted_times = [time.strftime("%Y-%m-%d %H:%M:%S") for time in recent_times]
+    option_values = [
+        OptionValue(label=formatted_time, name=f"time_{i + 1}", value=formatted_time)
+        for i, formatted_time in enumerate(formatted_times)
+    ]
+
+    return option_values
+
+
+class ExampleFlowRefreshOperator(MapOperator[str, str]):
+    """An example flow operator that includes a refresh option."""
+
+    metadata = ViewMetadata(
+        label="Example Refresh Operator",
+        name="example_refresh_operator",
+        category=OperatorCategory.EXAMPLE,
+        description="An example flow operator that includes a refresh option.",
+        parameters=[
+            Parameter.build_from(
+                "Time Interval",
+                "time_interval",
+                type=int,
+                optional=True,
+                default=1,
+                placeholder="Set the time interval",
+                description="The time interval to fetch the times",
+            ),
+            Parameter.build_from(
+                "Recent Time",
+                "recent_time",
+                type=str,
+                optional=True,
+                default=None,
+                placeholder="Select the recent time",
+                description="The recent time to choose.",
+                options=FunctionDynamicOptions(func=get_recent_3_times),
+                ui=ui.UISelect(
+                    refresh=True,
+                    refresh_depends=["time_interval"],
+                    attr=ui.UISelect.UIAttribute(show_search=True),
+                ),
+            ),
+        ],
+        inputs=[
+            IOField.build_from(
+                "User Name",
+                "user_name",
+                str,
+                description="The name of the user.",
+            )
+        ],
+        outputs=[
+            IOField.build_from(
+                "Time",
+                "time",
+                str,
+                description="User's selected time.",
+            )
+        ],
+    )
+
+    def __init__(
+        self, time_interval: int = 1, recent_time: Optional[str] = None, **kwargs
+    ):
+        super().__init__(**kwargs)
+        self.time_interval = time_interval
+        self.recent_time = recent_time
+
+    async def map(self, user_name: str) -> str:
+        """Map the user name to the time."""
+        return "Your name is %s, and you chose the time %s." % (
+            user_name,
+            self.recent_time,
+        )
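+
+
+# Note: ``refresh=True`` together with ``refresh_depends=["time_interval"]``
+# asks the front end to re-run ``get_recent_3_times`` whenever the
+# "Time Interval" parameter changes, so the "Recent Time" options stay
+# consistent with the chosen interval.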
+
+
+class ExampleFlowUploadOperator(MapOperator[str, str]):
+    """An example flow operator that includes an upload as a parameter."""
+
+    metadata = ViewMetadata(
+        label="Example Flow Upload",
+        name="example_flow_upload",
+        category=OperatorCategory.EXAMPLE,
+        description="An example flow operator that includes an upload as a parameter.",
+        parameters=[
+            Parameter.build_from(
+                "Single File Selector",
+                "file",
+                type=str,
+                optional=True,
+                default=None,
+                placeholder="Select the file",
+                description="The file you want to upload.",
+                ui=ui.UIUpload(
+                    max_file_size=1024 * 1024 * 100,
+                    up_event="after_select",
+                    attr=ui.UIUpload.UIAttribute(max_count=1),
+                ),
+            ),
+            Parameter.build_from(
+                "Multiple Files Selector",
+                "multiple_files",
+                type=str,
+                is_list=True,
+                optional=True,
+                default=None,
+                placeholder="Select the multiple files",
+                description="The multiple files you want to upload.",
+                ui=ui.UIUpload(
+                    max_file_size=1024 * 1024 * 100,
+                    up_event="button_click",
+                    attr=ui.UIUpload.UIAttribute(max_count=5),
+                ),
+            ),
+            Parameter.build_from(
+                "CSV File Selector",
+                "csv_file",
+                type=str,
+                optional=True,
+                default=None,
+                placeholder="Select the CSV file",
+                description="The CSV file you want to upload.",
+                ui=ui.UIUpload(
+                    max_file_size=1024 * 1024 * 100,
+                    up_event="after_select",
+                    file_types=[".csv"],
+                    attr=ui.UIUpload.UIAttribute(max_count=1),
+                ),
+            ),
+            Parameter.build_from(
+                "Images Selector",
+                "images",
+                type=str,
+                is_list=True,
+                optional=True,
+                default=None,
+                placeholder="Select the images",
+                description="The images you want to upload.",
+                ui=ui.UIUpload(
+                    max_file_size=1024 * 1024 * 100,
+                    up_event="button_click",
+                    file_types=["image/*", "*.pdf"],
+                    drag=True,
+                    attr=ui.UIUpload.UIAttribute(max_count=5),
+                ),
+            ),
+        ],
+        inputs=[
+            IOField.build_from(
+                "User Name",
+                "user_name",
+                str,
+                description="The name of the user.",
+            )
+        ],
+        outputs=[
+            IOField.build_from(
+                "File",
+                "file",
+                str,
+                description="User's uploaded file.",
+            )
+        ],
+    )
+
+    def __init__(
+        self,
+        file: Optional[str] = None,
+        multiple_files: Optional[List[str]] = None,
+        csv_file: Optional[str] = None,
+        images: Optional[List[str]] = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.file = file
+        self.multiple_files = multiple_files or []
+        self.csv_file = csv_file
+        self.images = images or []
+
+    async def map(self, user_name: str) -> str:
+        """Map the user name to the file."""
+        fsc = FileStorageClient.get_instance(self.system_app)
+        files_metadata = await self.blocking_func_to_async(
+            self._parse_files_metadata, fsc
+        )
+        files_metadata_str = json.dumps(files_metadata, ensure_ascii=False)
+        return "Your name is %s, and your files are %s." % (
+            user_name,
+            files_metadata_str,
+        )
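+
+    # ``fsc.get_file`` returns a ``(file_data, metadata)`` tuple; only the
+    # metadata is used here. Reading files is blocking I/O, which is why
+    # ``map`` wraps ``_parse_files_metadata`` in ``blocking_func_to_async``.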
+
+    def _parse_files_metadata(self, fsc: FileStorageClient) -> List[Dict[str, Any]]:
+        """Parse the metadata of the uploaded files."""
+        if not self.file:
+            raise ValueError("The file is not uploaded.")
+        if not self.multiple_files:
+            raise ValueError("The multiple files are not uploaded.")
+        files = [self.file, *self.multiple_files, self.csv_file, *self.images]
+        # The CSV file and images are optional, so skip any that were not
+        # uploaded instead of passing ``None`` to the file storage client.
+        files = [f for f in files if f]
+        results = []
+        for file in files:
+            _, metadata = fsc.get_file(file)
+            results.append(
+                {
+                    "bucket": metadata.bucket,
+                    "file_id": metadata.file_id,
+                    "file_size": metadata.file_size,
+                    "storage_type": metadata.storage_type,
+                    "uri": metadata.uri,
+                    "file_hash": metadata.file_hash,
+                }
+            )
+        return results
+
+
+class ExampleFlowVariablesOperator(MapOperator[str, str]):
+    """An example flow operator that includes a variables option."""
+
+    metadata = ViewMetadata(
+        label="Example Variables Operator",
+        name="example_variables_operator",
+        category=OperatorCategory.EXAMPLE,
+        description="An example flow operator that includes a variables option.",
+        parameters=[
+            Parameter.build_from(
+                "OpenAI API Key",
+                "openai_api_key",
+                type=str,
+                placeholder="Please select the OpenAI API key",
+                description="The OpenAI API key to use.",
+                options=VariablesDynamicOptions(),
+                ui=ui.UIPasswordInput(
+                    key="dbgpt.model.openai.api_key",
+                ),
+            ),
+            Parameter.build_from(
+                "Model",
+                "model",
+                type=str,
+                placeholder="Please select the model",
+                description="The model to use.",
+                options=VariablesDynamicOptions(),
+                ui=ui.UIVariablesInput(
+                    key="dbgpt.model.openai.model",
+                ),
+            ),
+            Parameter.build_from(
+                "Builtin Flows",
+                "builtin_flow",
+                type=str,
+                placeholder="Please select the builtin flows",
+                description="The builtin flows to use.",
+                options=VariablesDynamicOptions(),
+                ui=ui.UIVariablesInput(
+                    key=BUILTIN_VARIABLES_CORE_FLOWS,
+                ),
+            ),
+            Parameter.build_from(
+                "Builtin Flow Nodes",
+                "builtin_flow_node",
+                type=str,
+                placeholder="Please select the builtin flow nodes",
+                description="The builtin flow nodes to use.",
+                options=VariablesDynamicOptions(),
+                ui=ui.UIVariablesInput(
+                    key=BUILTIN_VARIABLES_CORE_FLOW_NODES,
+                ),
+            ),
+            Parameter.build_from(
+                "Builtin Variables",
+                "builtin_variable",
+                type=str,
+                placeholder="Please select the builtin variables",
+                description="The builtin variables to use.",
+                options=VariablesDynamicOptions(),
+                ui=ui.UIVariablesInput(
+                    key=BUILTIN_VARIABLES_CORE_VARIABLES,
+                ),
+            ),
+            Parameter.build_from(
+                "Builtin Secrets",
+                "builtin_secret",
+                type=str,
+                placeholder="Please select the builtin secrets",
+                description="The builtin secrets to use.",
+                options=VariablesDynamicOptions(),
+                ui=ui.UIVariablesInput(
+                    key=BUILTIN_VARIABLES_CORE_SECRETS,
+                ),
+            ),
+            Parameter.build_from(
+                "Builtin LLMs",
+                "builtin_llm",
+                type=str,
+                placeholder="Please select the builtin LLMs",
+                description="The builtin LLMs to use.",
+                options=VariablesDynamicOptions(),
+                ui=ui.UIVariablesInput(
+                    key=BUILTIN_VARIABLES_CORE_LLMS,
+                ),
+            ),
+            Parameter.build_from(
+                "Builtin Embeddings",
+                "builtin_embedding",
+                type=str,
+                placeholder="Please select the builtin embeddings",
+                description="The builtin embeddings to use.",
+                options=VariablesDynamicOptions(),
+                ui=ui.UIVariablesInput(
+                    key=BUILTIN_VARIABLES_CORE_EMBEDDINGS,
+                ),
+            ),
+        ],
+        inputs=[
+            IOField.build_from(
+                "User Name",
+                "user_name",
+                str,
+                description="The name of the user.",
+            ),
+        ],
+        outputs=[
+            IOField.build_from(
+                "Model info",
+                "model",
+                str,
+                description="The model info.",
+            ),
+        ],
+    )
+
+    def __init__(
+        self,
+        openai_api_key: str,
+        model: str,
+        builtin_flow: str,
+        builtin_flow_node: str,
+        builtin_variable: str,
+        builtin_secret: str,
+        builtin_llm: str,
+        builtin_embedding: str,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.openai_api_key = openai_api_key
+        self.model = model
+        self.builtin_flow = builtin_flow
+        self.builtin_flow_node = builtin_flow_node
+        self.builtin_variable = builtin_variable
+        self.builtin_secret = builtin_secret
+        self.builtin_llm = builtin_llm
+        self.builtin_embedding = builtin_embedding
+
+    async def map(self, user_name: str) -> str:
+        """Map the user name to the model info."""
+        variables_dict = {
+            "openai_api_key": self.openai_api_key,
+            "model": self.model,
+            "builtin_flow": self.builtin_flow,
+            "builtin_flow_node": self.builtin_flow_node,
+            "builtin_variable": self.builtin_variable,
+            "builtin_secret": self.builtin_secret,
+            "builtin_llm": self.builtin_llm,
+            "builtin_embedding": self.builtin_embedding,
+        }
+        json_data = json.dumps(variables_dict, ensure_ascii=False)
+        return "Your name is %s, and your model info is %s." % (user_name, json_data)
+
+
+class ExampleFlowTagsOperator(MapOperator[str, str]):
+    """An example flow operator that includes tags."""
+
+    metadata = ViewMetadata(
+        label="Example Tags Operator",
+        name="example_tags_operator",
+        category=OperatorCategory.EXAMPLE,
+        description="An example flow operator that includes tags.",
+        parameters=[],
+        inputs=[
+            IOField.build_from(
+                "User Name",
+                "user_name",
+                str,
+                description="The name of the user.",
+            ),
+        ],
+        outputs=[
+            IOField.build_from(
+                "Tags",
+                "tags",
+                str,
+                description="The tags to use.",
+            ),
+        ],
+        tags={"order": "higher-order", "type": "example"},
+    )
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+    async def map(self, user_name: str) -> str:
+        """Map the user name to the tags."""
+        return "Your name is %s, and your tags are %s." % (user_name, "higher-order")
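+
+
+# Note: the ``tags`` entry in the metadata above is presumably used by the
+# flow canvas to filter and group operators (e.g. by "order" or "type"); it
+# does not change the operator's runtime behaviour.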
+
+
+class ExampleFlowCodeEditorOperator(MapOperator[str, str]):
+    """An example flow operator that includes a code editor as a parameter."""
+
+    metadata = ViewMetadata(
+        label="Example Flow Code Editor",
+        name="example_flow_code_editor",
+        category=OperatorCategory.EXAMPLE,
+        description="An example flow operator that includes a code editor as a parameter.",
+        parameters=[
+            Parameter.build_from(
+                "Code Editor",
+                "code",
+                type=str,
+                placeholder="Please input your code",
+                description="The code you want to edit.",
+                ui=ui.UICodeEditor(
+                    language="python",
+                ),
+            )
+        ],
+        inputs=[
+            IOField.build_from(
+                "User Name",
+                "user_name",
+                str,
+                description="The name of the user.",
+            )
+        ],
+        outputs=[
+            IOField.build_from(
+                "Code",
+                "code",
+                str,
+                description="Result of the code.",
+            )
+        ],
+    )
+
+    def __init__(self, code: str, **kwargs):
+        super().__init__(**kwargs)
+        self.code = code
+
+    async def map(self, user_name: str) -> str:
+        """Map the user name to the code execution result."""
+        from dbgpt.util.code_utils import UNKNOWN, extract_code
+
+        code = self.code
+        exitcode = -1
+        try:
+            code_blocks = extract_code(self.code)
+            if len(code_blocks) < 1:
+                logger.info(f"No executable code found in: \n{code}")
+                raise ValueError(f"No executable code found in: \n{code}")
+            elif len(code_blocks) > 1 and code_blocks[0][0] == UNKNOWN:
+                # Multiple code blocks, but the first one has no language tag,
+                # so we cannot execute it reliably.
+                logger.info(
+                    f"Missing available code block type, unable to execute code,"
+                    f"\n{code}",
+                )
+                raise ValueError(
+                    "Missing available code block type, unable to execute code, "
+                    f"\n{code}"
+                )
+            exitcode, logs = await self.blocking_func_to_async(
+                self.execute_code_blocks, code_blocks
+            )
+        except Exception as e:
+            logger.error(f"Failed to execute code: {e}")
+            logs = f"Failed to execute code: {e}"
+        return (
+            f"Your name is {user_name}, and your code is \n\n```python\n{self.code}"
+            f"\n\n```\n\nThe execution result is \n\n```\n{logs}\n\n```\n\n"
+            f"Exit code: {exitcode}."
+        )
+
+    def execute_code_blocks(self, code_blocks):
+        """Execute the code blocks and return the result."""
+        from dbgpt.util.code_utils import execute_code, infer_lang
+        from dbgpt.util.utils import colored
+
+        logs_all = ""
+        exitcode = -1
+        _code_execution_config = {"use_docker": False}
+        for i, code_block in enumerate(code_blocks):
+            lang, code = code_block
+            if not lang:
+                lang = infer_lang(code)
+            print(
+                colored(
+                    f"\n>>>>>>>> EXECUTING CODE BLOCK {i} "
+                    f"(inferred language is {lang})...",
+                    "red",
+                ),
+                flush=True,
+            )
+            if lang in ["bash", "shell", "sh"]:
+                exitcode, logs, image = execute_code(
+                    code, lang=lang, **_code_execution_config
+                )
+            elif lang in ["python", "Python"]:
+                if code.startswith("# filename: "):
+                    filename = code[len("# filename: ") : code.find("\n")].strip()
+                else:
+                    filename = None
+                exitcode, logs, image = execute_code(
+                    code,
+                    lang="python",
+                    filename=filename,
+                    **_code_execution_config,
+                )
+            else:
+                # In case the language is not supported, we return an error message.
+                exitcode, logs, image = (
+                    1,
+                    f"unknown language {lang}",
+                    None,
+                )
+            if image is not None:
+                _code_execution_config["use_docker"] = image
+            logs_all += "\n" + logs
+            if exitcode != 0:
+                # Stop at the first failing code block.
+                return exitcode, logs_all
+        return exitcode, logs_all
diff --git a/setup.py b/setup.py
index cbe5592ce..a968892df 100644
--- a/setup.py
+++ b/setup.py
@@ -498,6 +498,8 @@ def core_requires():
         "GitPython",
         # For AWEL dag visualization, graphviz is a small package, also we can move it to default.
         "graphviz",
+        # For security
+        "cryptography",
     ]