From 3bfbdc3fc7d78937d0fff0272052413fc3832e8f Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Mon, 31 Mar 2025 09:38:31 +0800 Subject: [PATCH] feat(agent): More general ReAct Agent (#2556) --- Makefile | 4 + assets/schema/dbgpt.sql | 2 +- .../upgrade/v0_7_1/upgrade_to_v0.7.1.sql | 6 + assets/schema/upgrade/v0_7_1/v0.7.0.sql | 582 ++++++++++++++++++ examples/agents/react_agent_example.py | 37 +- install_help.py | 9 +- .../src/dbgpt_app/component_configs.py | 2 + .../src/dbgpt/agent/core/action/base.py | 40 +- .../src/dbgpt/agent/core/agent_manage.py | 45 +- .../src/dbgpt/agent/core/base_agent.py | 121 +++- .../dbgpt/agent/core/memory/agent_memory.py | 11 + .../src/dbgpt/agent/core/memory/base.py | 26 +- .../src/dbgpt/agent/core/memory/gpts/base.py | 8 + .../core/memory/gpts/default_gpts_memory.py | 4 + .../agent/core/memory/gpts/gpts_memory.py | 65 +- .../src/dbgpt/agent/core/memory/hybrid.py | 48 +- .../src/dbgpt/agent/core/memory/long_term.py | 171 ++++- .../src/dbgpt/agent/core/memory/short_term.py | 112 ++-- .../dbgpt-core/src/dbgpt/agent/core/role.py | 35 +- .../agent/expand/actions/react_action.py | 201 +++++- .../dbgpt/agent/expand/actions/tool_action.py | 107 ++-- .../src/dbgpt/agent/expand/react_agent.py | 418 ++++++------- .../src/dbgpt/agent/resource/base.py | 21 +- .../src/dbgpt/agent/resource/pack.py | 38 +- .../src/dbgpt/agent/resource/tool/pack.py | 30 +- .../src/dbgpt/agent/util/conv_utils.py | 17 + .../src/dbgpt/agent/util/react_parser.py | 212 +++++++ .../agent/util/tests/test_react_parser.py | 404 ++++++++++++ .../src/dbgpt/rag/retriever/time_weighted.py | 293 ++++++++- .../dbgpt_serve/agent/agents/controller.py | 28 +- .../agent/agents/db_gpts_memory.py | 4 +- .../dbgpt_serve/agent/db/gpts_messages_db.py | 20 +- scripts/update_version_all.py | 321 +++++----- 33 files changed, 2759 insertions(+), 683 deletions(-) create mode 100644 assets/schema/upgrade/v0_7_1/upgrade_to_v0.7.1.sql create mode 100644 assets/schema/upgrade/v0_7_1/v0.7.0.sql 
create mode 100644 packages/dbgpt-core/src/dbgpt/agent/util/conv_utils.py create mode 100644 packages/dbgpt-core/src/dbgpt/agent/util/react_parser.py create mode 100644 packages/dbgpt-core/src/dbgpt/agent/util/tests/test_react_parser.py diff --git a/Makefile b/Makefile index d54c1e1d1..a6c7163cd 100644 --- a/Makefile +++ b/Makefile @@ -53,10 +53,14 @@ fmt: setup ## Format Python code $(VENV_BIN)/ruff format packages $(VENV_BIN)/ruff format --exclude="examples/notebook" examples $(VENV_BIN)/ruff format i18n + $(VENV_BIN)/ruff format scripts/update_version_all.py + $(VENV_BIN)/ruff format install_help.py # Sort imports $(VENV_BIN)/ruff check --select I --fix packages $(VENV_BIN)/ruff check --select I --fix --exclude="examples/notebook" examples $(VENV_BIN)/ruff check --select I --fix i18n + $(VENV_BIN)/ruff check --select I --fix scripts/update_version_all.py + $(VENV_BIN)/ruff check --select I --fix install_help.py $(VENV_BIN)/ruff check --fix packages \ --exclude="packages/dbgpt-serve/src/**" diff --git a/assets/schema/dbgpt.sql b/assets/schema/dbgpt.sql index be343f65c..f2b3cdbb2 100644 --- a/assets/schema/dbgpt.sql +++ b/assets/schema/dbgpt.sql @@ -259,7 +259,7 @@ CREATE TABLE `gpts_messages` ( `current_goal` text COMMENT 'The target corresponding to the current message', `context` text COMMENT 'Current conversation context', `review_info` text COMMENT 'Current conversation review info', - `action_report` text COMMENT 'Current conversation action report', + `action_report` longtext COMMENT 'Current conversation action report', `resource_info` text DEFAULT NULL COMMENT 'Current conversation resource info', `role` varchar(255) DEFAULT NULL COMMENT 'The role of the current message content', `created_at` datetime DEFAULT NULL COMMENT 'create time', diff --git a/assets/schema/upgrade/v0_7_1/upgrade_to_v0.7.1.sql b/assets/schema/upgrade/v0_7_1/upgrade_to_v0.7.1.sql new file mode 100644 index 000000000..7d1d4b7e7 --- /dev/null +++ 
b/assets/schema/upgrade/v0_7_1/upgrade_to_v0.7.1.sql @@ -0,0 +1,6 @@ +-- From 0.7.0 to 0.7.1, we have the following changes: +USE dbgpt; + +-- Change action_report column type from text to longtext in gpts_messages table +ALTER TABLE `gpts_messages` + MODIFY COLUMN `action_report` longtext COMMENT 'Current conversation action report'; diff --git a/assets/schema/upgrade/v0_7_1/v0.7.0.sql b/assets/schema/upgrade/v0_7_1/v0.7.0.sql new file mode 100644 index 000000000..ec6e0ebba --- /dev/null +++ b/assets/schema/upgrade/v0_7_1/v0.7.0.sql @@ -0,0 +1,582 @@ +-- Full SQL of v0.7.0, please not modify this file(It must be same as the file in the release package) + +CREATE +DATABASE IF NOT EXISTS dbgpt; +use dbgpt; + +-- For alembic migration tool +CREATE TABLE IF NOT EXISTS `alembic_version` +( + version_num VARCHAR(32) NOT NULL, + CONSTRAINT alembic_version_pkc PRIMARY KEY (version_num) +) DEFAULT CHARSET=utf8mb4 ; + +CREATE TABLE IF NOT EXISTS `knowledge_space` +( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'auto increment id', + `name` varchar(100) NOT NULL COMMENT 'knowledge space name', + `vector_type` varchar(50) NOT NULL COMMENT 'vector type', + `domain_type` varchar(50) NOT NULL COMMENT 'domain type', + `desc` varchar(500) NOT NULL COMMENT 'description', + `owner` varchar(100) DEFAULT NULL COMMENT 'owner', + `context` TEXT DEFAULT NULL COMMENT 'context argument', + `gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'created time', + `gmt_modified` TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time', + PRIMARY KEY (`id`), + KEY `idx_name` (`name`) COMMENT 'index:idx_name' +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT='knowledge space table'; + +CREATE TABLE IF NOT EXISTS `knowledge_document` +( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'auto increment id', + `doc_name` varchar(100) NOT NULL COMMENT 'document path name', + `doc_type` varchar(50) NOT NULL COMMENT 'doc type', + `doc_token` varchar(100) 
NULL COMMENT 'doc token', + `space` varchar(50) NOT NULL COMMENT 'knowledge space', + `chunk_size` int NOT NULL COMMENT 'chunk size', + `last_sync` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'last sync time', + `status` varchar(50) NOT NULL COMMENT 'status TODO,RUNNING,FAILED,FINISHED', + `content` LONGTEXT NOT NULL COMMENT 'knowledge embedding sync result', + `result` TEXT NULL COMMENT 'knowledge content', + `questions` TEXT NULL COMMENT 'document related questions', + `vector_ids` LONGTEXT NULL COMMENT 'vector_ids', + `summary` LONGTEXT NULL COMMENT 'knowledge summary', + `gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'created time', + `gmt_modified` TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time', + PRIMARY KEY (`id`), + KEY `idx_doc_name` (`doc_name`) COMMENT 'index:idx_doc_name' +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT='knowledge document table'; + +CREATE TABLE IF NOT EXISTS `document_chunk` +( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'auto increment id', + `doc_name` varchar(100) NOT NULL COMMENT 'document path name', + `doc_type` varchar(50) NOT NULL COMMENT 'doc type', + `document_id` int NOT NULL COMMENT 'document parent id', + `content` longtext NOT NULL COMMENT 'chunk content', + `questions` text NULL COMMENT 'chunk related questions', + `meta_info` text NOT NULL COMMENT 'metadata info', + `gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time', + `gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time', + PRIMARY KEY (`id`), + KEY `idx_document_id` (`document_id`) COMMENT 'index:document_id' +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT='knowledge document chunk detail'; + + +CREATE TABLE IF NOT EXISTS `connect_config` +( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', + `db_type` varchar(255) NOT NULL COMMENT 'db type', + `db_name` varchar(255) NOT NULL COMMENT 'db name', 
+ `db_path` varchar(255) DEFAULT NULL COMMENT 'file db path', + `db_host` varchar(255) DEFAULT NULL COMMENT 'db connect host(not file db)', + `db_port` varchar(255) DEFAULT NULL COMMENT 'db cnnect port(not file db)', + `db_user` varchar(255) DEFAULT NULL COMMENT 'db user', + `db_pwd` varchar(255) DEFAULT NULL COMMENT 'db password', + `comment` text COMMENT 'db comment', + `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code', + `user_name` varchar(255) DEFAULT NULL COMMENT 'user name', + `user_id` varchar(255) DEFAULT NULL COMMENT 'user id', + `gmt_created` datetime DEFAULT CURRENT_TIMESTAMP COMMENT 'Record creation time', + `gmt_modified` datetime DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'Record update time', + `ext_config` text COMMENT 'Extended configuration, json format', + PRIMARY KEY (`id`), + UNIQUE KEY `uk_db` (`db_name`), + KEY `idx_q_db_type` (`db_type`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT 'Connection confi'; + +CREATE TABLE IF NOT EXISTS `chat_history` +( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', + `conv_uid` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'Conversation record unique id', + `chat_mode` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'Conversation scene mode', + `summary` longtext COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'Conversation record summary', + `user_name` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'interlocutor', + `messages` text COLLATE utf8mb4_unicode_ci COMMENT 'Conversation details', + `message_ids` text COLLATE utf8mb4_unicode_ci COMMENT 'Message id list, split by comma', + `app_code` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'App unique code', + `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code', + `gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time', + `gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update 
time', + UNIQUE KEY `conv_uid` (`conv_uid`), + PRIMARY KEY (`id`), + KEY `idx_chat_his_app_code` (`app_code`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT 'Chat history'; + +CREATE TABLE IF NOT EXISTS `chat_history_message` +( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', + `conv_uid` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'Conversation record unique id', + `index` int NOT NULL COMMENT 'Message index', + `round_index` int NOT NULL COMMENT 'Round of conversation', + `message_detail` longtext COLLATE utf8mb4_unicode_ci COMMENT 'Message details, json format', + `gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time', + `gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time', + UNIQUE KEY `message_uid_index` (`conv_uid`, `index`), + PRIMARY KEY (`id`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT 'Chat history message'; + + +CREATE TABLE IF NOT EXISTS `chat_feed_back` +( + `id` bigint(20) NOT NULL AUTO_INCREMENT, + `conv_uid` varchar(128) DEFAULT NULL COMMENT 'Conversation ID', + `conv_index` int(4) DEFAULT NULL COMMENT 'Round of conversation', + `score` int(1) DEFAULT NULL COMMENT 'Score of user', + `ques_type` varchar(32) DEFAULT NULL COMMENT 'User question category', + `question` longtext DEFAULT NULL COMMENT 'User question', + `knowledge_space` varchar(128) DEFAULT NULL COMMENT 'Knowledge space name', + `messages` longtext DEFAULT NULL COMMENT 'The details of user feedback', + `message_id` varchar(255) NULL COMMENT 'Message id', + `feedback_type` varchar(50) NULL COMMENT 'Feedback type like or unlike', + `reason_types` varchar(255) NULL COMMENT 'Feedback reason categories', + `remark` text NULL COMMENT 'Feedback remark', + `user_code` varchar(128) NULL COMMENT 'User code', + `user_name` varchar(128) DEFAULT NULL COMMENT 'User name', + `gmt_created` timestamp NULL 
DEFAULT CURRENT_TIMESTAMP COMMENT 'created time', + `gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time', + PRIMARY KEY (`id`), + UNIQUE KEY `uk_conv` (`conv_uid`,`conv_index`), + KEY `idx_conv` (`conv_uid`,`conv_index`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT='User feedback table'; + + +CREATE TABLE IF NOT EXISTS `my_plugin` +( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', + `tenant` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'user tenant', + `user_code` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'user code', + `user_name` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'user name', + `name` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'plugin name', + `file_name` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'plugin package file name', + `type` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin type', + `version` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin version', + `use_count` int DEFAULT NULL COMMENT 'plugin total use count', + `succ_count` int DEFAULT NULL COMMENT 'plugin total success count', + `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code', + `gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'plugin install time', + PRIMARY KEY (`id`), + UNIQUE KEY `name` (`name`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='User plugin table'; + +CREATE TABLE IF NOT EXISTS `plugin_hub` +( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', + `name` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'plugin name', + `description` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'plugin description', + `author` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin author', + `email` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin author 
email', + `type` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin type', + `version` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin version', + `storage_channel` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin storage channel', + `storage_url` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin download url', + `download_param` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin download param', + `gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'plugin upload time', + `installed` int DEFAULT NULL COMMENT 'plugin already installed count', + PRIMARY KEY (`id`), + UNIQUE KEY `name` (`name`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='Plugin Hub table'; + + +CREATE TABLE IF NOT EXISTS `prompt_manage` +( + `id` int(11) NOT NULL AUTO_INCREMENT, + `chat_scene` varchar(100) DEFAULT NULL COMMENT 'Chat scene', + `sub_chat_scene` varchar(100) DEFAULT NULL COMMENT 'Sub chat scene', + `prompt_type` varchar(100) DEFAULT NULL COMMENT 'Prompt type: common or private', + `prompt_name` varchar(256) DEFAULT NULL COMMENT 'prompt name', + `prompt_code` varchar(256) DEFAULT NULL COMMENT 'prompt code', + `content` longtext COMMENT 'Prompt content', + `input_variables` varchar(1024) DEFAULT NULL COMMENT 'Prompt input variables(split by comma))', + `response_schema` text DEFAULT NULL COMMENT 'Prompt response schema', + `model` varchar(128) DEFAULT NULL COMMENT 'Prompt model name(we can use different models for different prompt)', + `prompt_language` varchar(32) DEFAULT NULL COMMENT 'Prompt language(eg:en, zh-cn)', + `prompt_format` varchar(32) DEFAULT 'f-string' COMMENT 'Prompt format(eg: f-string, jinja2)', + `prompt_desc` varchar(512) DEFAULT NULL COMMENT 'Prompt description', + `user_code` varchar(128) DEFAULT NULL COMMENT 'User code', + `user_name` varchar(128) DEFAULT NULL COMMENT 'User name', + `sys_code` varchar(128) DEFAULT 
NULL COMMENT 'System code', + `gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time', + `gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time', + PRIMARY KEY (`id`), + UNIQUE KEY `prompt_name_uiq` (`prompt_name`, `sys_code`, `prompt_language`, `model`), + KEY `gmt_created_idx` (`gmt_created`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='Prompt management table'; + + + + CREATE TABLE IF NOT EXISTS `gpts_conversations` ( + `id` int(11) NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', + `conv_id` varchar(255) NOT NULL COMMENT 'The unique id of the conversation record', + `user_goal` text NOT NULL COMMENT 'User''s goals content', + `gpts_name` varchar(255) NOT NULL COMMENT 'The gpts name', + `state` varchar(255) DEFAULT NULL COMMENT 'The gpts state', + `max_auto_reply_round` int(11) NOT NULL COMMENT 'max auto reply round', + `auto_reply_count` int(11) NOT NULL COMMENT 'auto reply count', + `user_code` varchar(255) DEFAULT NULL COMMENT 'user code', + `sys_code` varchar(255) DEFAULT NULL COMMENT 'system app ', + `created_at` datetime DEFAULT NULL COMMENT 'create time', + `updated_at` datetime DEFAULT NULL COMMENT 'last update time', + `team_mode` varchar(255) NULL COMMENT 'agent team work mode', + + PRIMARY KEY (`id`), + UNIQUE KEY `uk_gpts_conversations` (`conv_id`), + KEY `idx_gpts_name` (`gpts_name`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT="gpt conversations"; + +CREATE TABLE IF NOT EXISTS `gpts_instance` ( + `id` int(11) NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', + `gpts_name` varchar(255) NOT NULL COMMENT 'Current AI assistant name', + `gpts_describe` varchar(2255) NOT NULL COMMENT 'Current AI assistant describe', + `resource_db` text COMMENT 'List of structured database names contained in the current gpts', + `resource_internet` text COMMENT 'Is it possible to retrieve information from the internet', + 
`resource_knowledge` text COMMENT 'List of unstructured database names contained in the current gpts', + `gpts_agents` varchar(1000) DEFAULT NULL COMMENT 'List of agents names contained in the current gpts', + `gpts_models` varchar(1000) DEFAULT NULL COMMENT 'List of llm model names contained in the current gpts', + `language` varchar(100) DEFAULT NULL COMMENT 'gpts language', + `user_code` varchar(255) NOT NULL COMMENT 'user code', + `sys_code` varchar(255) DEFAULT NULL COMMENT 'system app code', + `created_at` datetime DEFAULT NULL COMMENT 'create time', + `updated_at` datetime DEFAULT NULL COMMENT 'last update time', + `team_mode` varchar(255) NOT NULL COMMENT 'Team work mode', + `is_sustainable` tinyint(1) NOT NULL COMMENT 'Applications for sustainable dialogue', + PRIMARY KEY (`id`), + UNIQUE KEY `uk_gpts` (`gpts_name`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT="gpts instance"; + +CREATE TABLE `gpts_messages` ( + `id` int(11) NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', + `conv_id` varchar(255) NOT NULL COMMENT 'The unique id of the conversation record', + `sender` varchar(255) NOT NULL COMMENT 'Who speaking in the current conversation turn', + `receiver` varchar(255) NOT NULL COMMENT 'Who receive message in the current conversation turn', + `model_name` varchar(255) DEFAULT NULL COMMENT 'message generate model', + `rounds` int(11) NOT NULL COMMENT 'dialogue turns', + `is_success` int(4) NULL DEFAULT 0 COMMENT 'agent message is success', + `app_code` varchar(255) NOT NULL COMMENT 'Current AI assistant code', + `app_name` varchar(255) NOT NULL COMMENT 'Current AI assistant name', + `content` text COMMENT 'Content of the speech', + `current_goal` text COMMENT 'The target corresponding to the current message', + `context` text COMMENT 'Current conversation context', + `review_info` text COMMENT 'Current conversation review info', + `action_report` text COMMENT 'Current conversation action report', + `resource_info` text DEFAULT NULL 
COMMENT 'Current conversation resource info', + `role` varchar(255) DEFAULT NULL COMMENT 'The role of the current message content', + `created_at` datetime DEFAULT NULL COMMENT 'create time', + `updated_at` datetime DEFAULT NULL COMMENT 'last update time', + PRIMARY KEY (`id`), + KEY `idx_q_messages` (`conv_id`,`rounds`,`sender`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT="gpts message"; + + +CREATE TABLE `gpts_plans` ( + `id` int(11) NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', + `conv_id` varchar(255) NOT NULL COMMENT 'The unique id of the conversation record', + `sub_task_num` int(11) NOT NULL COMMENT 'Subtask number', + `sub_task_title` varchar(255) NOT NULL COMMENT 'subtask title', + `sub_task_content` text NOT NULL COMMENT 'subtask content', + `sub_task_agent` varchar(255) DEFAULT NULL COMMENT 'Available agents corresponding to subtasks', + `resource_name` varchar(255) DEFAULT NULL COMMENT 'resource name', + `rely` varchar(255) DEFAULT NULL COMMENT 'Subtask dependencies,like: 1,2,3', + `agent_model` varchar(255) DEFAULT NULL COMMENT 'LLM model used by subtask processing agents', + `retry_times` int(11) DEFAULT NULL COMMENT 'number of retries', + `max_retry_times` int(11) DEFAULT NULL COMMENT 'Maximum number of retries', + `state` varchar(255) DEFAULT NULL COMMENT 'subtask status', + `result` longtext COMMENT 'subtask result', + `created_at` datetime DEFAULT NULL COMMENT 'create time', + `updated_at` datetime DEFAULT NULL COMMENT 'last update time', + PRIMARY KEY (`id`), + UNIQUE KEY `uk_sub_task` (`conv_id`,`sub_task_num`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT="gpt plan"; + +-- dbgpt.dbgpt_serve_flow definition +CREATE TABLE `dbgpt_serve_flow` ( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'Auto increment id', + `uid` varchar(128) NOT NULL COMMENT 'Unique id', + `dag_id` varchar(128) DEFAULT NULL COMMENT 'DAG id', + `name` varchar(128) DEFAULT NULL COMMENT 'Flow name', + `flow_data` text COMMENT 'Flow 
data, JSON format', + `user_name` varchar(128) DEFAULT NULL COMMENT 'User name', + `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code', + `gmt_created` datetime DEFAULT NULL COMMENT 'Record creation time', + `gmt_modified` datetime DEFAULT NULL COMMENT 'Record update time', + `flow_category` varchar(64) DEFAULT NULL COMMENT 'Flow category', + `description` varchar(512) DEFAULT NULL COMMENT 'Flow description', + `state` varchar(32) DEFAULT NULL COMMENT 'Flow state', + `error_message` varchar(512) NULL comment 'Error message', + `source` varchar(64) DEFAULT NULL COMMENT 'Flow source', + `source_url` varchar(512) DEFAULT NULL COMMENT 'Flow source url', + `version` varchar(32) DEFAULT NULL COMMENT 'Flow version', + `define_type` varchar(32) null comment 'Flow define type(json or python)', + `label` varchar(128) DEFAULT NULL COMMENT 'Flow label', + `editable` int DEFAULT NULL COMMENT 'Editable, 0: editable, 1: not editable', + `variables` text DEFAULT NULL COMMENT 'Flow variables, JSON format', + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`), + KEY `ix_dbgpt_serve_flow_sys_code` (`sys_code`), + KEY `ix_dbgpt_serve_flow_uid` (`uid`), + KEY `ix_dbgpt_serve_flow_dag_id` (`dag_id`), + KEY `ix_dbgpt_serve_flow_user_name` (`user_name`), + KEY `ix_dbgpt_serve_flow_name` (`name`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; + +-- dbgpt.dbgpt_serve_file definition +CREATE TABLE `dbgpt_serve_file` ( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'Auto increment id', + `bucket` varchar(255) NOT NULL COMMENT 'Bucket name', + `file_id` varchar(255) NOT NULL COMMENT 'File id', + `file_name` varchar(256) NOT NULL COMMENT 'File name', + `file_size` int DEFAULT NULL COMMENT 'File size', + `storage_type` varchar(32) NOT NULL COMMENT 'Storage type', + `storage_path` varchar(512) NOT NULL COMMENT 'Storage path', + `uri` varchar(512) NOT NULL COMMENT 'File URI', + `custom_metadata` text DEFAULT NULL COMMENT 'Custom metadata, JSON format', + 
`file_hash` varchar(128) DEFAULT NULL COMMENT 'File hash', + `user_name` varchar(128) DEFAULT NULL COMMENT 'User name', + `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code', + `gmt_created` datetime DEFAULT CURRENT_TIMESTAMP COMMENT 'Record creation time', + `gmt_modified` datetime DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'Record update time', + PRIMARY KEY (`id`), + UNIQUE KEY `uk_bucket_file_id` (`bucket`, `file_id`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; + +-- dbgpt.dbgpt_serve_variables definition +CREATE TABLE `dbgpt_serve_variables` ( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'Auto increment id', + `key` varchar(128) NOT NULL COMMENT 'Variable key', + `name` varchar(128) DEFAULT NULL COMMENT 'Variable name', + `label` varchar(128) DEFAULT NULL COMMENT 'Variable label', + `value` text DEFAULT NULL COMMENT 'Variable value, JSON format', + `value_type` varchar(32) DEFAULT NULL COMMENT 'Variable value type(string, int, float, bool)', + `category` varchar(32) DEFAULT 'common' COMMENT 'Variable category(common or secret)', + `encryption_method` varchar(32) DEFAULT NULL COMMENT 'Variable encryption method(fernet, simple, rsa, aes)', + `salt` varchar(128) DEFAULT NULL COMMENT 'Variable salt', + `scope` varchar(32) DEFAULT 'global' COMMENT 'Variable scope(global,flow,app,agent,datasource,flow_priv,agent_priv, ""etc)', + `scope_key` varchar(256) DEFAULT NULL COMMENT 'Variable scope key, default is empty, for scope is "flow_priv", the scope_key is dag id of flow', + `enabled` int DEFAULT 1 COMMENT 'Variable enabled, 0: disabled, 1: enabled', + `description` text DEFAULT NULL COMMENT 'Variable description', + `user_name` varchar(128) DEFAULT NULL COMMENT 'User name', + `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code', + `gmt_created` datetime DEFAULT CURRENT_TIMESTAMP COMMENT 'Record creation time', + `gmt_modified` datetime DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 
'Record update time', + PRIMARY KEY (`id`), + KEY `ix_your_table_name_key` (`key`), + KEY `ix_your_table_name_name` (`name`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; + +CREATE TABLE `dbgpt_serve_model` ( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'Auto increment id', + `host` varchar(255) NOT NULL COMMENT 'The model worker host', + `port` int NOT NULL COMMENT 'The model worker port', + `model` varchar(255) NOT NULL COMMENT 'The model name', + `provider` varchar(255) NOT NULL COMMENT 'The model provider', + `worker_type` varchar(255) NOT NULL COMMENT 'The worker type', + `params` text NOT NULL COMMENT 'The model parameters, JSON format', + `enabled` int DEFAULT 1 COMMENT 'Whether the model is enabled, if it is enabled, it will be started when the system starts, 1 is enabled, 0 is disabled', + `worker_name` varchar(255) DEFAULT NULL COMMENT 'The worker name', + `description` text DEFAULT NULL COMMENT 'The model description', + `user_name` varchar(128) DEFAULT NULL COMMENT 'User name', + `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code', + `gmt_created` datetime DEFAULT CURRENT_TIMESTAMP COMMENT 'Record creation time', + `gmt_modified` datetime DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'Record update time', + PRIMARY KEY (`id`), + KEY `idx_user_name` (`user_name`), + KEY `idx_sys_code` (`sys_code`), + UNIQUE KEY `uk_model_provider_type` (`model`, `provider`, `worker_type`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='Model persistence table'; + +-- dbgpt.gpts_app definition +CREATE TABLE `gpts_app` ( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', + `app_code` varchar(255) NOT NULL COMMENT 'Current AI assistant code', + `app_name` varchar(255) NOT NULL COMMENT 'Current AI assistant name', + `app_describe` varchar(2255) NOT NULL COMMENT 'Current AI assistant describe', + `language` varchar(100) NOT NULL COMMENT 'gpts language', + 
`team_mode` varchar(255) NOT NULL COMMENT 'Team work mode', + `team_context` text COMMENT 'The execution logic and team member content that teams with different working modes rely on', + `user_code` varchar(255) DEFAULT NULL COMMENT 'user code', + `sys_code` varchar(255) DEFAULT NULL COMMENT 'system app code', + `created_at` datetime DEFAULT NULL COMMENT 'create time', + `updated_at` datetime DEFAULT NULL COMMENT 'last update time', + `icon` varchar(1024) DEFAULT NULL COMMENT 'app icon, url', + `published` varchar(64) DEFAULT 'false' COMMENT 'Has it been published?', + `param_need` text DEFAULT NULL COMMENT 'Parameter information supported by the application', + `admins` text DEFAULT NULL COMMENT 'administrator', + PRIMARY KEY (`id`), + UNIQUE KEY `uk_gpts_app` (`app_name`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; + +CREATE TABLE `gpts_app_collection` ( + `id` int(11) NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', + `app_code` varchar(255) NOT NULL COMMENT 'Current AI assistant code', + `user_code` int(11) NOT NULL COMMENT 'user code', + `sys_code` varchar(255) NULL COMMENT 'system app code', + `created_at` datetime DEFAULT NULL COMMENT 'create time', + `updated_at` datetime DEFAULT NULL COMMENT 'last update time', + PRIMARY KEY (`id`), + KEY `idx_app_code` (`app_code`), + KEY `idx_user_code` (`user_code`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT="gpt collections"; + +-- dbgpt.gpts_app_detail definition +CREATE TABLE `gpts_app_detail` ( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', + `app_code` varchar(255) NOT NULL COMMENT 'Current AI assistant code', + `app_name` varchar(255) NOT NULL COMMENT 'Current AI assistant name', + `agent_name` varchar(255) NOT NULL COMMENT ' Agent name', + `node_id` varchar(255) NOT NULL COMMENT 'Current AI assistant Agent Node id', + `resources` text COMMENT 'Agent bind resource', + `prompt_template` text COMMENT 'Agent bind template', + 
`llm_strategy` varchar(25) DEFAULT NULL COMMENT 'Agent use llm strategy', + `llm_strategy_value` text COMMENT 'Agent use llm strategy value', + `created_at` datetime DEFAULT NULL COMMENT 'create time', + `updated_at` datetime DEFAULT NULL COMMENT 'last update time', + PRIMARY KEY (`id`), + UNIQUE KEY `uk_gpts_app_agent_node` (`app_name`,`agent_name`,`node_id`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; + + +-- For deploy model cluster of DB-GPT(StorageModelRegistry) +CREATE TABLE IF NOT EXISTS `dbgpt_cluster_registry_instance` ( + `id` int(11) NOT NULL AUTO_INCREMENT COMMENT 'Auto increment id', + `model_name` varchar(128) NOT NULL COMMENT 'Model name', + `host` varchar(128) NOT NULL COMMENT 'Host of the model', + `port` int(11) NOT NULL COMMENT 'Port of the model', + `weight` float DEFAULT 1.0 COMMENT 'Weight of the model', + `check_healthy` tinyint(1) DEFAULT 1 COMMENT 'Whether to check the health of the model', + `healthy` tinyint(1) DEFAULT 0 COMMENT 'Whether the model is healthy', + `enabled` tinyint(1) DEFAULT 1 COMMENT 'Whether the model is enabled', + `prompt_template` varchar(128) DEFAULT NULL COMMENT 'Prompt template for the model instance', + `last_heartbeat` datetime DEFAULT NULL COMMENT 'Last heartbeat time of the model instance', + `user_name` varchar(128) DEFAULT NULL COMMENT 'User name', + `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code', + `gmt_created` datetime DEFAULT CURRENT_TIMESTAMP COMMENT 'Record creation time', + `gmt_modified` datetime DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'Record update time', + PRIMARY KEY (`id`), + UNIQUE KEY `uk_model_instance` (`model_name`, `host`, `port`, `sys_code`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT='Cluster model instance table, for registering and managing model instances'; + +-- dbgpt.recommend_question definition +CREATE TABLE `recommend_question` ( + `id` bigint(20) unsigned NOT NULL AUTO_INCREMENT 
COMMENT 'autoincrement id', + `gmt_create` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'create time', + `gmt_modified` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'last update time', + `app_code` varchar(255) NOT NULL COMMENT 'Current AI assistant code', + `question` text DEFAULT NULL COMMENT 'question', + `user_code` varchar(255) NOT NULL COMMENT 'user code', + `sys_code` varchar(255) NULL COMMENT 'system app code', + `valid` varchar(10) DEFAULT 'true' COMMENT 'is it effective,true/false', + `chat_mode` varchar(255) DEFAULT NULL COMMENT 'Conversation scene mode,chat_knowledge...', + `params` text DEFAULT NULL COMMENT 'question param', + `is_hot_question` varchar(10) DEFAULT 'false' COMMENT 'Is it a popular recommendation question?', + PRIMARY KEY (`id`), + KEY `idx_rec_q_app_code` (`app_code`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT="AI application related recommendation issues"; + +-- dbgpt.user_recent_apps definition +CREATE TABLE `user_recent_apps` ( + `id` bigint(20) unsigned NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', + `gmt_create` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'create time', + `gmt_modified` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'last update time', + `app_code` varchar(255) NOT NULL COMMENT 'AI assistant code', + `last_accessed` timestamp NULL DEFAULT NULL COMMENT 'User recent usage time', + `user_code` varchar(255) DEFAULT NULL COMMENT 'user code', + `sys_code` varchar(255) DEFAULT NULL COMMENT 'system app code', + PRIMARY KEY (`id`), + KEY `idx_user_r_app_code` (`app_code`), + KEY `idx_last_accessed` (`last_accessed`), + KEY `idx_user_code` (`user_code`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='User recently used apps'; + +-- dbgpt.dbgpt_serve_dbgpts_my definition +CREATE TABLE `dbgpt_serve_dbgpts_my` ( + `id` int NOT NULL 
AUTO_INCREMENT COMMENT 'autoincrement id', + `name` varchar(255) NOT NULL COMMENT 'plugin name', + `user_code` varchar(255) DEFAULT NULL COMMENT 'user code', + `user_name` varchar(255) DEFAULT NULL COMMENT 'user name', + `file_name` varchar(255) NOT NULL COMMENT 'plugin package file name', + `type` varchar(255) DEFAULT NULL COMMENT 'plugin type', + `version` varchar(255) DEFAULT NULL COMMENT 'plugin version', + `use_count` int DEFAULT NULL COMMENT 'plugin total use count', + `succ_count` int DEFAULT NULL COMMENT 'plugin total success count', + `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code', + `gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'plugin install time', + `gmt_modified` TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time', + PRIMARY KEY (`id`), + UNIQUE KEY `name` (`name`, `user_name`), + KEY `ix_my_plugin_sys_code` (`sys_code`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; + +-- dbgpt.dbgpt_serve_dbgpts_hub definition +CREATE TABLE `dbgpt_serve_dbgpts_hub` ( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', + `name` varchar(255) NOT NULL COMMENT 'plugin name', + `description` varchar(255) NULL COMMENT 'plugin description', + `author` varchar(255) DEFAULT NULL COMMENT 'plugin author', + `email` varchar(255) DEFAULT NULL COMMENT 'plugin author email', + `type` varchar(255) DEFAULT NULL COMMENT 'plugin type', + `version` varchar(255) DEFAULT NULL COMMENT 'plugin version', + `storage_channel` varchar(255) DEFAULT NULL COMMENT 'plugin storage channel', + `storage_url` varchar(255) DEFAULT NULL COMMENT 'plugin download url', + `download_param` varchar(255) DEFAULT NULL COMMENT 'plugin download param', + `gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'plugin upload time', + `gmt_modified` TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time', + `installed` int DEFAULT NULL COMMENT 'plugin already installed count', 
+ PRIMARY KEY (`id`), + UNIQUE KEY `name` (`name`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; + + +CREATE +DATABASE IF NOT EXISTS EXAMPLE_1; +use EXAMPLE_1; +CREATE TABLE IF NOT EXISTS `users` +( + `id` int NOT NULL AUTO_INCREMENT, + `username` varchar(50) NOT NULL COMMENT '用户名', + `password` varchar(50) NOT NULL COMMENT '密码', + `email` varchar(50) NOT NULL COMMENT '邮箱', + `phone` varchar(20) DEFAULT NULL COMMENT '电话', + PRIMARY KEY (`id`), + KEY `idx_username` (`username`) COMMENT '索引:按用户名查询' +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT='聊天用户表'; + +INSERT INTO users (username, password, email, phone) +VALUES ('user_1', 'password_1', 'user_1@example.com', '12345678901'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_2', 'password_2', 'user_2@example.com', '12345678902'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_3', 'password_3', 'user_3@example.com', '12345678903'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_4', 'password_4', 'user_4@example.com', '12345678904'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_5', 'password_5', 'user_5@example.com', '12345678905'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_6', 'password_6', 'user_6@example.com', '12345678906'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_7', 'password_7', 'user_7@example.com', '12345678907'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_8', 'password_8', 'user_8@example.com', '12345678908'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_9', 'password_9', 'user_9@example.com', '12345678909'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_10', 'password_10', 'user_10@example.com', '12345678900'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_11', 'password_11', 'user_11@example.com', '12345678901'); 
+INSERT INTO users (username, password, email, phone) +VALUES ('user_12', 'password_12', 'user_12@example.com', '12345678902'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_13', 'password_13', 'user_13@example.com', '12345678903'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_14', 'password_14', 'user_14@example.com', '12345678904'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_15', 'password_15', 'user_15@example.com', '12345678905'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_16', 'password_16', 'user_16@example.com', '12345678906'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_17', 'password_17', 'user_17@example.com', '12345678907'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_18', 'password_18', 'user_18@example.com', '12345678908'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_19', 'password_19', 'user_19@example.com', '12345678909'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_20', 'password_20', 'user_20@example.com', '12345678900'); \ No newline at end of file diff --git a/examples/agents/react_agent_example.py b/examples/agents/react_agent_example.py index 47c85a817..9aaf0ff23 100644 --- a/examples/agents/react_agent_example.py +++ b/examples/agents/react_agent_example.py @@ -6,6 +6,7 @@ import sys from typing_extensions import Annotated, Doc from dbgpt.agent import AgentContext, AgentMemory, LLMConfig, UserProxyAgent +from dbgpt.agent.expand.actions.react_action import ReActAction, Terminate from dbgpt.agent.expand.react_agent import ReActAgent from dbgpt.agent.resource import ToolPack, tool @@ -16,17 +17,13 @@ logging.basicConfig( ) -@tool -def terminate( - final_answer: Annotated[str, Doc("final literal answer about the goal")], -) -> str: - """When the goal achieved, this tool must be called.""" - return final_answer - - @tool def 
simple_calculator(first_number: int, second_number: int, operator: str) -> float: - """Simple calculator tool. Just support +, -, *, /.""" + """Simple calculator tool. Just support +, -, *, /. + When users need to do numerical calculations, you must use this tool to calculate, \ + and you are not allowed to directly infer calculation results from user input or \ + external observations. + """ if isinstance(first_number, str): first_number = int(first_number) if isinstance(second_number, str): @@ -52,22 +49,28 @@ def count_directory_files(path: Annotated[str, Doc("The directory path")]) -> in async def main(): - from dbgpt.model.proxy.llms.siliconflow import SiliconFlowLLMClient + from dbgpt.model import AutoLLMClient - llm_client = SiliconFlowLLMClient( - model_alias="Qwen/Qwen2-7B-Instruct", + llm_client = AutoLLMClient( + # provider=os.getenv("LLM_PROVIDER", "proxy/deepseek"), + # name=os.getenv("LLM_MODEL_NAME", "deepseek-chat"), + provider=os.getenv("LLM_PROVIDER", "proxy/siliconflow"), + name=os.getenv("LLM_MODEL_NAME", "Qwen/Qwen2.5-Coder-32B-Instruct"), ) agent_memory = AgentMemory() agent_memory.gpts_memory.init(conv_id="test456") - context: AgentContext = AgentContext(conv_id="test456", gpts_app_name="ReAct") + # It is important to set the temperature to a low value to get a better result + context: AgentContext = AgentContext( + conv_id="test456", gpts_app_name="ReAct", temperature=0.01 + ) - tools = ToolPack([simple_calculator, count_directory_files, terminate]) + tools = ToolPack([simple_calculator, count_directory_files, Terminate()]) user_proxy = await UserProxyAgent().bind(agent_memory).bind(context).build() tool_engineer = ( - await ReActAgent(end_action_name="terminate", max_steps=10) + await ReActAgent(max_retry_count=10) .bind(context) .bind(LLMConfig(llm_client=llm_client)) .bind(agent_memory) @@ -78,7 +81,9 @@ async def main(): await user_proxy.initiate_chat( recipient=tool_engineer, reviewer=user_proxy, - message="Calculate the product of 10 
and 99, Count the number of files in /tmp, answer in Chinese.", + message="Calculate the product of 10 and 99, then count the number of files in /tmp", + # message="Calculate the product of 10 and 99", + # message="Count the number of files in /tmp", ) # dbgpt-vis message infos diff --git a/install_help.py b/install_help.py index 958ac524a..45b12947e 100755 --- a/install_help.py +++ b/install_help.py @@ -8,13 +8,14 @@ # [tool.uv] # exclude-newer = "2025-03-07T00:00:00Z" # /// -import os -import tomli import glob +import os +from pathlib import Path +from typing import Any, Dict + import click import inquirer -from pathlib import Path -from typing import Dict, Any +import tomli # For I18N support, we use a simple class to store translations and a global instance diff --git a/packages/dbgpt-app/src/dbgpt_app/component_configs.py b/packages/dbgpt-app/src/dbgpt_app/component_configs.py index 7e513075d..482c2da6d 100644 --- a/packages/dbgpt-app/src/dbgpt_app/component_configs.py +++ b/packages/dbgpt-app/src/dbgpt_app/component_configs.py @@ -96,6 +96,7 @@ def _initialize_agent(system_app: SystemApp): def _initialize_resource_manager(system_app: SystemApp): + from dbgpt.agent.expand.actions.react_action import Terminate from dbgpt.agent.expand.resources.dbgpt_tool import list_dbgpt_support_models from dbgpt.agent.expand.resources.host_tool import ( get_current_host_cpu_status, @@ -117,6 +118,7 @@ def _initialize_resource_manager(system_app: SystemApp): rm.register_resource(KnowledgeSpaceRetrieverResource) rm.register_resource(PluginToolPack, resource_type=ResourceType.Tool) rm.register_resource(GptAppResource) + rm.register_resource(resource_instance=Terminate()) # Register a search tool rm.register_resource(resource_instance=baidu_search) rm.register_resource(resource_instance=list_dbgpt_support_models) diff --git a/packages/dbgpt-core/src/dbgpt/agent/core/action/base.py b/packages/dbgpt-core/src/dbgpt/agent/core/action/base.py index be8059c84..6cf401216 100644 --- 
a/packages/dbgpt-core/src/dbgpt/agent/core/action/base.py +++ b/packages/dbgpt-core/src/dbgpt/agent/core/action/base.py @@ -46,12 +46,20 @@ class ActionOutput(BaseModel): resource_type: Optional[str] = None resource_value: Optional[Any] = None action: Optional[str] = None + action_input: Optional[str] = None thoughts: Optional[str] = None observations: Optional[str] = None have_retry: Optional[bool] = True ask_user: Optional[bool] = False # 如果当前agent能确定下个发言者,需要在这里指定 next_speakers: Optional[List[str]] = None + # Terminate the conversation, it is a special action + # If terminate is True, it means the conversation is over, it will stop the + # conversation loop forcibly. + terminate: Optional[bool] = None + # Memory fragments of current conversation, we can recover the conversation at any + # time. + memory_fragments: Optional[Dict[str, Any]] = None @model_validator(mode="before") @classmethod @@ -81,10 +89,11 @@ class ActionOutput(BaseModel): class Action(ABC, Generic[T]): """Base Action class for defining agent actions.""" - def __init__(self, language: str = "en"): + def __init__(self, language: str = "en", name: Optional[str] = None): """Create an action.""" self.resource: Optional[Resource] = None self.language: str = language + self._name = name def init_resource(self, resource: Optional[Resource]): """Initialize the resource.""" @@ -95,6 +104,21 @@ class Action(ABC, Generic[T]): """Return the resource type needed for the action.""" return None + @property + def name(self) -> str: + """Return the action name.""" + if self._name: + return self._name + _name = self.__class__.__name__ + if _name.endswith("Action"): + return _name[:-6] + return _name + + @classmethod + def get_action_description(cls) -> str: + """Return the action description.""" + return cls.__doc__ or "" + @property def render_protocol(self) -> Optional[Vis]: """Return the render protocol.""" @@ -185,6 +209,20 @@ class Action(ABC, Generic[T]): typed_cls = cast(Type[BaseModel], cls) return 
typed_cls.model_validate(json_result) + @classmethod + def parse_action( + cls, + ai_message: str, + default_action: "Action", + resource: Optional[Resource] = None, + **kwargs, + ) -> Optional["Action"]: + """Parse the action from the message. + + If you want skip the action, return None. + """ + return default_action + @abstractmethod async def run( self, diff --git a/packages/dbgpt-core/src/dbgpt/agent/core/agent_manage.py b/packages/dbgpt-core/src/dbgpt/agent/core/agent_manage.py index 4215bafbf..36b5b9ba5 100644 --- a/packages/dbgpt-core/src/dbgpt/agent/core/agent_manage.py +++ b/packages/dbgpt-core/src/dbgpt/agent/core/agent_manage.py @@ -7,8 +7,6 @@ from typing import Dict, List, Optional, Set, Tuple, Type, cast from dbgpt.component import BaseComponent, ComponentType, SystemApp -from ..expand.Indicator_assistant_agent import IndicatorAssistantAgent -from ..expand.simple_assistant_agent import SimpleAssistantAgent from .agent import Agent from .base_agent import ConversableAgent @@ -67,22 +65,11 @@ class AgentManager(BaseComponent): def after_start(self): """Register all agents.""" - from ..expand.code_assistant_agent import CodeAssistantAgent - from ..expand.dashboard_assistant_agent import DashboardAssistantAgent - from ..expand.data_scientist_agent import DataScientistAgent - from ..expand.summary_assistant_agent import SummaryAssistantAgent - from ..expand.tool_assistant_agent import ToolAssistantAgent + core_agents = scan_agents() + for _, agent in core_agents.items(): + self.register_agent(agent) - core_agents = set() - core_agents.add(self.register_agent(CodeAssistantAgent)) - core_agents.add(self.register_agent(DashboardAssistantAgent)) - core_agents.add(self.register_agent(DataScientistAgent)) - core_agents.add(self.register_agent(SummaryAssistantAgent)) - core_agents.add(self.register_agent(ToolAssistantAgent)) - core_agents.add(self.register_agent(IndicatorAssistantAgent)) - core_agents.add(self.register_agent(SimpleAssistantAgent)) - - 
self._core_agents = core_agents + self._core_agents = list(core_agents.values()) def register_agent( self, cls: Type[ConversableAgent], ignore_duplicate: bool = False @@ -163,3 +150,27 @@ def get_agent_manager(system_app: Optional[SystemApp] = None) -> AgentManager: initialize_agent(system_app) app = system_app or _SYSTEM_APP return AgentManager.get_instance(cast(SystemApp, app)) + + +_HAS_SCAN = False + + +def scan_agents(): + """Scan and register all agents.""" + from dbgpt.util.module_utils import ModelScanner, ScannerConfig + + from .base_agent import ConversableAgent + + global _HAS_SCAN + + if _HAS_SCAN: + return + scanner = ModelScanner[ConversableAgent]() + config = ScannerConfig( + module_path="dbgpt.agent.expand", + base_class=ConversableAgent, + recursive=True, + ) + scanner.scan_and_register(config) + _HAS_SCAN = True + return scanner.get_registered_items() diff --git a/packages/dbgpt-core/src/dbgpt/agent/core/base_agent.py b/packages/dbgpt-core/src/dbgpt/agent/core/base_agent.py index 7255b99da..be5c34872 100644 --- a/packages/dbgpt-core/src/dbgpt/agent/core/base_agent.py +++ b/packages/dbgpt-core/src/dbgpt/agent/core/base_agent.py @@ -19,6 +19,7 @@ from dbgpt.util.tracer import SpanType, root_tracer from dbgpt.util.utils import colored from ..resource.base import Resource +from ..util.conv_utils import parse_conv_id from ..util.llm.llm import LLMConfig, LLMStrategyType from ..util.llm.llm_client import AIWrapper from .action.base import Action, ActionOutput @@ -27,7 +28,7 @@ from .memory.agent_memory import AgentMemory from .memory.gpts.base import GptsMessage from .memory.gpts.gpts_memory import GptsMemory from .profile.base import ProfileConfig -from .role import Role +from .role import AgentRunMode, Role logger = logging.getLogger(__name__) @@ -42,6 +43,7 @@ class ConversableAgent(Role, Agent): resource: Optional[Resource] = Field(None, description="Resource") llm_config: Optional[LLMConfig] = None bind_prompt: Optional[PromptTemplate] = None + 
run_mode: Optional[AgentRunMode] = Field(default=None, description="Run mode") max_retry_count: int = 3 llm_client: Optional[AIWrapper] = None # 确认当前Agent是否需要进行流式输出 @@ -158,24 +160,21 @@ class ConversableAgent(Role, Agent): if not self.llm_config or not self.llm_config.llm_client: raise ValueError("LLM client is not initialized!") self.llm_client = AIWrapper(llm_client=self.llm_config.llm_client) + real_conv_id, _ = parse_conv_id(self.not_null_agent_context.conv_id) + memory_session = f"{real_conv_id}_{self.role}_{self.name}" self.memory.initialize( self.name, self.llm_config.llm_client, importance_scorer=self.memory_importance_scorer, insight_extractor=self.memory_insight_extractor, + session_id=memory_session, ) # Clone the memory structure self.memory = self.memory.structure_clone() - # init agent memory - if is_retry_chat: - # recover agent memory message - agent_history_memories = ( - await self.memory.gpts_memory.get_agent_history_memory( - self.not_null_agent_context.conv_id, self.role - ) - ) - for agent_history_memory in agent_history_memories: - await self.write_memories(**agent_history_memory) + action_outputs = await self.memory.gpts_memory.get_agent_history_memory( + real_conv_id, self.role + ) + await self.recovering_memory(action_outputs) return self def bind(self, target: Any) -> "ConversableAgent": @@ -194,6 +193,17 @@ class ConversableAgent(Role, Agent): self.profile = target elif isinstance(target, type) and issubclass(target, Action): self.actions.append(target()) + elif isinstance(target, Action): + self.actions.append(target) + elif isinstance(target, list) and all( + [isinstance(item, type) and issubclass(item, Action) for item in target] + ): + for action in target: + self.actions.append(action()) + elif isinstance(target, list) and all( + [isinstance(item, Action) for item in target] + ): + self.actions.extend(target) elif isinstance(target, PromptTemplate): self.bind_prompt = target @@ -333,30 +343,48 @@ class ConversableAgent(Role, 
Agent): ), }, ) + reply_message = None try: with root_tracer.start_span( "agent.generate_reply._init_reply_message", ) as span: # initialize reply message - reply_message: AgentMessage = self._init_reply_message( - received_message=received_message - ) + a_reply_message: Optional[ + AgentMessage + ] = await self._a_init_reply_message(received_message=received_message) + if a_reply_message: + reply_message = a_reply_message + else: + reply_message = self._init_reply_message( + received_message=received_message + ) span.metadata["reply_message"] = reply_message.to_dict() fail_reason = None current_retry_counter = 0 is_success = True - while current_retry_counter < self.max_retry_count: + done = False + observation = received_message.content or "" + while not done and current_retry_counter < self.max_retry_count: if current_retry_counter > 0: - retry_message = self._init_reply_message( + a_reply_message: Optional[ + AgentMessage + ] = await self._a_init_reply_message( received_message=received_message, rely_messages=rely_messages, ) + if a_reply_message: + retry_message = a_reply_message + else: + retry_message = self._init_reply_message( + received_message=received_message, + rely_messages=rely_messages, + ) retry_message.rounds = reply_message.rounds + 1 - retry_message.content = fail_reason + retry_message.content = fail_reason or observation retry_message.current_goal = received_message.current_goal # The current message is a self-optimized message that needs to be @@ -463,13 +491,8 @@ class ConversableAgent(Role, Agent): # 5.Optimize wrong answers myself if not check_pass: if not act_out.have_retry: + logger.warning("No retry available!") break - current_retry_counter += 1 - # Send error messages and issue new problem-solving instructions - if current_retry_counter < self.max_retry_count: - await self.send( - reply_message, sender, reviewer, request_reply=False - ) fail_reason = reason await self.write_memories( question=question, @@ -479,13 +502,26 @@ class 
ConversableAgent(Role, Agent): check_fail_reason=fail_reason, ) else: + # Successful reply + observation = act_out.observations await self.write_memories( question=question, ai_message=ai_message, action_output=act_out, check_pass=check_pass, ) - break + if self.run_mode != AgentRunMode.LOOP or act_out.terminate: + logger.debug(f"Agent {self.name} reply success!{reply_message}") + break + + # Continue to run the next round + current_retry_counter += 1 + # Send error messages and issue new problem-solving instructions + if current_retry_counter < self.max_retry_count: + await self.send( + reply_message, sender, reviewer, request_reply=False + ) + reply_message.success = is_success # 6.final message adjustment await self.adjust_final_message(is_success, reply_message) @@ -497,7 +533,8 @@ class ConversableAgent(Role, Agent): err_message.success = False return err_message finally: - root_span.metadata["reply_message"] = reply_message.to_dict() + if reply_message: + root_span.metadata["reply_message"] = reply_message.to_dict() root_span.end() async def thinking( @@ -583,7 +620,14 @@ class ConversableAgent(Role, Agent): "total_action": len(self.actions), }, ) as span: - last_out = await action.run( + ai_message = message.content if message.content else "" + real_action = action.parse_action( + ai_message, default_action=action, **kwargs + ) + if real_action is None: + continue + + last_out = await real_action.run( ai_message=message.content if message.content else "", resource=None, rely_action_out=last_out, @@ -931,6 +975,17 @@ class ConversableAgent(Role, Agent): rounds=received_message.rounds + 1, ) + async def _a_init_reply_message( + self, + received_message: AgentMessage, + rely_messages: Optional[List[AgentMessage]] = None, + ) -> Optional[AgentMessage]: + """Create a new message from the received message. + + If return not None, the `_init_reply_message` method will not be called. 
+ """ + return None + def _convert_to_ai_message( self, gpts_messages: List[GptsMessage], @@ -1025,6 +1080,7 @@ class ConversableAgent(Role, Agent): if not observation: raise ValueError("The received message content is empty!") memories = await self.read_memories(observation) + has_memories = True if memories else False reply_message_str = "" if context is None: context = {} @@ -1070,6 +1126,8 @@ class ConversableAgent(Role, Agent): resource_vars=resource_vars, **context, ) + if not user_prompt: + user_prompt = "Observation: " agent_messages = [] if system_prompt: @@ -1079,21 +1137,22 @@ class ConversableAgent(Role, Agent): role=ModelMessageRoleType.SYSTEM, ) ) - # 关联上下文的历史消息 - if historical_dialogues: + if historical_dialogues and not has_memories: + # If we can't read the memory, we need to rely on the historical dialogue for i in range(len(historical_dialogues)): if i % 2 == 0: - # 偶数开始, 偶数是用户信息 + # The even number starts, and the even number is the user + # information message = historical_dialogues[i] message.role = ModelMessageRoleType.HUMAN agent_messages.append(message) else: - # 奇数是AI信息 + # The odd number is AI information message = historical_dialogues[i] message.role = ModelMessageRoleType.AI agent_messages.append(message) - # 当前的用户输入信息 + # Current user input information agent_messages.append( AgentMessage( content=user_prompt, diff --git a/packages/dbgpt-core/src/dbgpt/agent/core/memory/agent_memory.py b/packages/dbgpt-core/src/dbgpt/agent/core/memory/agent_memory.py index 370e17404..a18f5b5ba 100644 --- a/packages/dbgpt-core/src/dbgpt/agent/core/memory/agent_memory.py +++ b/packages/dbgpt-core/src/dbgpt/agent/core/memory/agent_memory.py @@ -224,6 +224,7 @@ class AgentMemory(Memory[AgentMemoryFragment]): importance_scorer: Optional[ImportanceScorer[AgentMemoryFragment]] = None, insight_extractor: Optional[InsightExtractor[AgentMemoryFragment]] = None, real_memory_fragment_class: Optional[Type[AgentMemoryFragment]] = None, + session_id: Optional[str] = 
None, ) -> None: """Initialize the memory.""" self.memory.initialize( @@ -233,6 +234,7 @@ class AgentMemory(Memory[AgentMemoryFragment]): insight_extractor=insight_extractor or self.insight_extractor, real_memory_fragment_class=real_memory_fragment_class or AgentMemoryFragment, + session_id=session_id, ) @mutable @@ -245,6 +247,15 @@ class AgentMemory(Memory[AgentMemoryFragment]): """Write a memory fragment to the memory.""" return await self.memory.write(memory_fragment, now) + @mutable + async def write_batch( + self, + memory_fragments: List[AgentMemoryFragment], + now: Optional[datetime] = None, + ) -> Optional[DiscardedMemoryFragments[AgentMemoryFragment]]: + """Write a batch of memory fragments to the memory.""" + return await self.memory.write_batch(memory_fragments, now) + @immutable async def read( self, diff --git a/packages/dbgpt-core/src/dbgpt/agent/core/memory/base.py b/packages/dbgpt-core/src/dbgpt/agent/core/memory/base.py index fd3f1d442..ddcadfe59 100644 --- a/packages/dbgpt-core/src/dbgpt/agent/core/memory/base.py +++ b/packages/dbgpt-core/src/dbgpt/agent/core/memory/base.py @@ -345,6 +345,8 @@ class Memory(ABC, Generic[T]): insight_extractor: Optional[InsightExtractor] = None _real_memory_fragment_class: Optional[Type[T]] = None importance_weight: float = 0.15 + # The session id is used to identify the session of the agent. + session_id: Optional[str] = None @mutable def initialize( @@ -354,6 +356,7 @@ class Memory(ABC, Generic[T]): importance_scorer: Optional[ImportanceScorer] = None, insight_extractor: Optional[InsightExtractor] = None, real_memory_fragment_class: Optional[Type[T]] = None, + session_id: Optional[str] = None, ) -> None: """Initialize memory. 
@@ -364,6 +367,7 @@ class Memory(ABC, Generic[T]): self.importance_scorer = importance_scorer self.insight_extractor = insight_extractor self._real_memory_fragment_class = real_memory_fragment_class + self.session_id = session_id @abstractmethod @immutable @@ -400,6 +404,7 @@ class Memory(ABC, Generic[T]): self.importance_scorer = memory.importance_scorer self.insight_extractor = memory.insight_extractor self._real_memory_fragment_class = memory._real_memory_fragment_class + self.session_id = memory.session_id @abstractmethod @mutable @@ -442,7 +447,22 @@ class Memory(ABC, Generic[T]): Optional[DiscardedMemoryFragments]: The discarded memory fragments, None means no memory fragments are discarded. """ - raise NotImplementedError + discarded_memory_fragments = [] + discarded_insights = [] + for memory_fragment in memory_fragments: + discarded_memory = await self.write(memory_fragment, now) + if discarded_memory: + if discarded_memory.discarded_memory_fragments: + discarded_memory_fragments.extend( + discarded_memory.discarded_memory_fragments + ) + if discarded_memory.discarded_insights: + discarded_insights.extend(discarded_memory.discarded_insights) + return ( + DiscardedMemoryFragments(discarded_memory_fragments, discarded_insights) + if discarded_memory_fragments + else None + ) @abstractmethod @immutable @@ -698,7 +718,9 @@ class ShortTermMemory(Memory, Generic[T]): self: "ShortTermMemory[T]", now: Optional[datetime] = None ) -> "ShortTermMemory[T]": """Return a structure clone of the memory.""" - m: ShortTermMemory[T] = ShortTermMemory(buffer_size=self._buffer_size) + m: ShortTermMemory[T] = ShortTermMemory( + buffer_size=self._buffer_size, + ) m._copy_from(self) return m diff --git a/packages/dbgpt-core/src/dbgpt/agent/core/memory/gpts/base.py b/packages/dbgpt-core/src/dbgpt/agent/core/memory/gpts/base.py index 073296a04..0508f37d5 100644 --- a/packages/dbgpt-core/src/dbgpt/agent/core/memory/gpts/base.py +++ 
b/packages/dbgpt-core/src/dbgpt/agent/core/memory/gpts/base.py @@ -258,3 +258,11 @@ class GptsMessageMemory(ABC): Returns: GptsMessage: The last message in the conversation """ + + @abstractmethod + def delete_by_conv_id(self, conv_id: str) -> None: + """Delete messages by conversation id. + + Args: + conv_id(str): Conversation id + """ diff --git a/packages/dbgpt-core/src/dbgpt/agent/core/memory/gpts/default_gpts_memory.py b/packages/dbgpt-core/src/dbgpt/agent/core/memory/gpts/default_gpts_memory.py index b95dbc11b..11758a1d9 100644 --- a/packages/dbgpt-core/src/dbgpt/agent/core/memory/gpts/default_gpts_memory.py +++ b/packages/dbgpt-core/src/dbgpt/agent/core/memory/gpts/default_gpts_memory.py @@ -147,3 +147,7 @@ class DefaultGptsMessageMemory(GptsMessageMemory): def get_last_message(self, conv_id: str) -> Optional[GptsMessage]: """Get the last message in the conversation.""" return None + + def delete_by_conv_id(self, conv_id: str) -> None: + """Delete all messages in the conversation.""" + self.df.drop(self.df[self.df["conv_id"] == conv_id].index, inplace=True) diff --git a/packages/dbgpt-core/src/dbgpt/agent/core/memory/gpts/gpts_memory.py b/packages/dbgpt-core/src/dbgpt/agent/core/memory/gpts/gpts_memory.py index a8c07666f..c4b4c3fec 100644 --- a/packages/dbgpt-core/src/dbgpt/agent/core/memory/gpts/gpts_memory.py +++ b/packages/dbgpt-core/src/dbgpt/agent/core/memory/gpts/gpts_memory.py @@ -5,8 +5,10 @@ import json import logging from asyncio import Queue from collections import defaultdict +from concurrent.futures import Executor, ThreadPoolExecutor from typing import Dict, List, Optional, Union +from dbgpt.util.executor_utils import blocking_func_to_async from dbgpt.vis.client import VisAgentMessages, VisAgentPlans, VisAppLink, vis_client from ...action.base import ActionOutput @@ -26,6 +28,7 @@ class GptsMemory: self, plans_memory: Optional[GptsPlansMemory] = None, message_memory: Optional[GptsMessageMemory] = None, + executor: Optional[Executor] = None, ): 
"""Create a memory to store plans and messages.""" self._plans_memory: GptsPlansMemory = ( @@ -34,7 +37,7 @@ class GptsMemory: self._message_memory: GptsMessageMemory = ( message_memory if message_memory is not None else DefaultGptsMessageMemory() ) - + self._executor = executor or ThreadPoolExecutor(max_workers=2) self.messages_cache: defaultdict = defaultdict(list) self.channels: defaultdict = defaultdict(Queue) self.enable_vis_map: defaultdict = defaultdict(bool) @@ -118,10 +121,10 @@ class GptsMemory: async def append_message(self, conv_id: str, message: GptsMessage): """Append message.""" - # 中期记忆 self.messages_cache[conv_id].append(message) - # 长期记忆 - self.message_memory.append(message) + await blocking_func_to_async( + self._executor, self.message_memory.append, message + ) # 消息记忆后发布消息 await self.push_message(conv_id) @@ -130,7 +133,9 @@ class GptsMemory: """Get message by conv_id.""" messages = self.messages_cache[conv_id] if not messages: - messages = self.message_memory.get_by_conv_id(conv_id) + messages = await blocking_func_to_async( + self._executor, self.message_memory.get_by_conv_id, conv_id + ) return messages async def get_agent_messages( @@ -144,28 +149,34 @@ class GptsMemory: result.append(gpt_message) return result - async def get_agent_history_memory(self, conv_id: str, agent_role: str) -> List: + async def get_agent_history_memory( + self, conv_id: str, agent_role: str + ) -> List[ActionOutput]: """Get agent history memory.""" - gpt_messages = self.messages_cache[conv_id] - agent_messages = [] - for gpt_message in gpt_messages: - if gpt_message.sender == agent_role or gpt_message.receiver == agent_role: - agent_messages.append(gpt_message) - - new_list = [ - { - "question": agent_messages[i].content, - "ai_message": agent_messages[i + 1].content, - "action_output": ActionOutput.from_dict( + agent_messages = await blocking_func_to_async( + self._executor, self.message_memory.get_by_agent, conv_id, agent_role + ) + new_list = [] + for i in 
range(0, len(agent_messages), 2): + if i + 1 >= len(agent_messages): + break + action_report = None + if agent_messages[i + 1].action_report: + action_report = ActionOutput.from_dict( json.loads(agent_messages[i + 1].action_report) - ), - "check_pass": agent_messages[i + 1].is_success, - } - for i in range(0, len(agent_messages), 2) - ] + ) + new_list.append( + { + "question": agent_messages[i].content, + "ai_message": agent_messages[i + 1].content, + "action_output": action_report, + "check_pass": agent_messages[i + 1].is_success, + } + ) - return new_list + # Just use the action_output now + return [m["action_output"] for m in new_list if m["action_output"]] async def _message_group_vis_build(self, message_group, vis_items: list): num: int = 0 @@ -260,7 +271,9 @@ class GptsMemory: if messages_cache and len(messages_cache) > 0: messages = messages_cache else: - messages = self.message_memory.get_by_conv_id(conv_id=conv_id) + messages = await blocking_func_to_async( + self._executor, self.message_memory.get_by_conv_id, conv_id=conv_id + ) simple_message_list = [] for message in messages: @@ -299,7 +312,9 @@ class GptsMemory: ) messages = messages_cache[start_round:] else: - messages = self.message_memory.get_by_conv_id(conv_id=conv_id) + messages = await blocking_func_to_async( + self._executor, self.message_memory.get_by_conv_id, conv_id=conv_id + ) # VIS消息组装 temp_group: Dict = {} diff --git a/packages/dbgpt-core/src/dbgpt/agent/core/memory/hybrid.py b/packages/dbgpt-core/src/dbgpt/agent/core/memory/hybrid.py index 0cac076e8..2d46d3951 100644 --- a/packages/dbgpt-core/src/dbgpt/agent/core/memory/hybrid.py +++ b/packages/dbgpt-core/src/dbgpt/agent/core/memory/hybrid.py @@ -81,7 +81,7 @@ class HybridMemory(Memory, Generic[T]): ): """Create a hybrid memory from Chroma vector store.""" from dbgpt.configs.model_config import DATA_DIR - from dbgpt.storage.vector_store.chroma_store import ( + from dbgpt_ext.storage.vector_store.chroma_store import ( ChromaStore, 
ChromaVectorConfig, ) @@ -152,6 +152,7 @@ class HybridMemory(Memory, Generic[T]): importance_scorer: Optional[ImportanceScorer[T]] = None, insight_extractor: Optional[InsightExtractor[T]] = None, real_memory_fragment_class: Optional[Type[T]] = None, + session_id: Optional[str] = None, ) -> None: """Initialize the memory. @@ -168,6 +169,7 @@ class HybridMemory(Memory, Generic[T]): "importance_scorer": importance_scorer, "insight_extractor": insight_extractor, "real_memory_fragment_class": real_memory_fragment_class, + "session_id": session_id, } for memory in memories: memory.initialize(**kwargs) @@ -181,8 +183,25 @@ class HybridMemory(Memory, Generic[T]): op: WriteOperation = WriteOperation.ADD, ) -> Optional[DiscardedMemoryFragments[T]]: """Write a memory fragment to the memory.""" + return await self._write_single( + memory_fragment, + now=now, + op=op, + write_long_term=True, + ) + + async def _write_single( + self, + memory_fragment: T, + now: Optional[datetime] = None, + op: WriteOperation = WriteOperation.ADD, + write_long_term: bool = True, + ) -> Optional[DiscardedMemoryFragments[T]]: + """Write a single memory fragment to the memory.""" # First write to sensory memory - sen_discarded_memories = await self._sensory_memory.write(memory_fragment) + sen_discarded_memories = await self._sensory_memory.write( + memory_fragment, now=now, op=op + ) if not sen_discarded_memories: return None short_term_discarded_memories = [] @@ -190,7 +209,9 @@ class HybridMemory(Memory, Generic[T]): discarded_insights = [] for sen_memory in sen_discarded_memories.discarded_memory_fragments: # Write to short term memory - short_discarded_memory = await self._short_term_memory.write(sen_memory) + short_discarded_memory = await self._short_term_memory.write( + sen_memory, now=now, op=op + ) if short_discarded_memory: short_term_discarded_memories.append(short_discarded_memory) discarded_memory_fragments.extend( @@ -199,17 +220,36 @@ class HybridMemory(Memory, Generic[T]): for 
insight in short_discarded_memory.discarded_insights: # Just keep the first insight discarded_insights.append(insight.insights[0]) + if not write_long_term: + return None # Obtain the importance of insights insight_scores = await self.score_memory_importance(discarded_insights) # Get the importance of insights for i, ins in enumerate(discarded_insights): ins.update_importance(insight_scores[i]) all_memories = discarded_memory_fragments + discarded_insights - if self._long_term_memory: + if self._long_term_memory and len(all_memories) > 0: # Write to long term memory await self._long_term_memory.write_batch(all_memories, self.now) return None + @mutable + async def write_batch( + self, memory_fragments: List[T], now: Optional[datetime] = None + ) -> Optional[DiscardedMemoryFragments[T]]: + """Write a batch of memory fragments to the memory. + + For memory recovery, we only write to sensory memory and short term memory. + """ + for memory_fragment in memory_fragments: + # Just write to sensory memory and short term memory + await self._write_single( + memory_fragment, + now=now, + op=WriteOperation.ADD, + write_long_term=False, + ) + @immutable async def read( self, diff --git a/packages/dbgpt-core/src/dbgpt/agent/core/memory/long_term.py b/packages/dbgpt-core/src/dbgpt/agent/core/memory/long_term.py index 80e1c46c2..54147aea3 100644 --- a/packages/dbgpt-core/src/dbgpt/agent/core/memory/long_term.py +++ b/packages/dbgpt-core/src/dbgpt/agent/core/memory/long_term.py @@ -2,12 +2,12 @@ from concurrent.futures import Executor from datetime import datetime -from typing import Generic, List, Optional +from typing import Any, Dict, Generic, List, Optional from dbgpt.core import Chunk from dbgpt.rag.retriever.time_weighted import TimeWeightedEmbeddingRetriever from dbgpt.storage.vector_store.base import VectorStoreBase -from dbgpt.storage.vector_store.filters import MetadataFilters +from dbgpt.storage.vector_store.filters import MetadataFilter, MetadataFilters from 
dbgpt.util.annotations import immutable, mutable from dbgpt.util.executor_utils import blocking_func_to_async @@ -17,14 +17,20 @@ _FORGET_PLACEHOLDER = "[FORGET]" _MERGE_PLACEHOLDER = "[MERGE]" _METADATA_BUFFER_IDX = "buffer_idx" _METADATA_LAST_ACCESSED_AT = "last_accessed_at" +_METADATA_SESSION_ID = "session_id" _METADAT_IMPORTANCE = "importance" class LongTermRetriever(TimeWeightedEmbeddingRetriever): - """Long-term retriever.""" + """Long-term retriever with persistence support.""" def __init__(self, now: datetime, **kwargs): - """Create a long-term retriever.""" + """Create a long-term retriever. + + Args: + now: Current datetime to use for time-based calculations + **kwargs: Additional arguments passed to TimeWeightedEmbeddingRetriever + """ self.now = now super().__init__(**kwargs) @@ -32,35 +38,143 @@ class LongTermRetriever(TimeWeightedEmbeddingRetriever): def _retrieve( self, query: str, filters: Optional[MetadataFilters] = None ) -> List[Chunk]: - """Retrieve memories.""" + """Retrieve memories based on query and time weights. + + Args: + query: The query string + filters: Optional metadata filters + + Returns: + List of relevant document chunks + """ + # Use the current time from self.now instead of generating a new one current_time = self.now - docs_and_scores = { - doc.metadata[_METADATA_BUFFER_IDX]: (doc, self.default_salience) - # Calculate for all memories. 
- for doc in self.memory_stream - } + + if self._use_vector_store_only: + # If operating in vector store only mode, use parent class implementation + # with custom adjustments for long-term memory + return self._retrieve_vector_store_only(query, filters, current_time) + + # Process all memories in memory_stream + docs_and_scores = {} + for doc in self.memory_stream: + if _METADATA_BUFFER_IDX in doc.metadata: + buffer_idx = doc.metadata[_METADATA_BUFFER_IDX] + docs_and_scores[buffer_idx] = (doc, self.default_salience) + # If a doc is considered salient, update the salience score - docs_and_scores.update(self.get_salient_docs(query)) + docs_and_scores.update(self.get_salient_docs(query, filters)) + + # If no documents found and we're in vector store only mode, fall back + if not docs_and_scores and self._use_vector_store_only: + return self._retrieve_vector_store_only(query, filters, current_time) + + # Calculate combined scores for all documents rescored_docs = [ (doc, self._get_combined_score(doc, relevance, current_time)) for doc, relevance in docs_and_scores.values() ] + + # Sort by score rescored_docs.sort(key=lambda x: x[1], reverse=True) + result = [] - # Ensure frequently accessed memories aren't forgotten retrieved_num = 0 + + # Process documents in order of score for doc, _ in rescored_docs: + # Skip documents that are marked for forgetting or merging if ( retrieved_num < self._k and doc.content.find(_FORGET_PLACEHOLDER) == -1 and doc.content.find(_MERGE_PLACEHOLDER) == -1 ): retrieved_num += 1 - buffered_doc = self.memory_stream[doc.metadata[_METADATA_BUFFER_IDX]] - buffered_doc.metadata[_METADATA_LAST_ACCESSED_AT] = current_time - result.append(buffered_doc) + + # Get the document from memory stream + if _METADATA_BUFFER_IDX in doc.metadata and 0 <= doc.metadata[ + _METADATA_BUFFER_IDX + ] < len(self.memory_stream): + buffered_doc = self.memory_stream[ + doc.metadata[_METADATA_BUFFER_IDX] + ] + buffered_doc.metadata[_METADATA_LAST_ACCESSED_AT] = 
current_time + result.append(buffered_doc) + else: + # Handle case where buffer_idx is invalid + doc.metadata[_METADATA_LAST_ACCESSED_AT] = current_time + result.append(doc) + + # Save memory stream after updating access times + self._save_memory_stream() + return result + def _retrieve_vector_store_only( + self, + query: str, + filters: Optional[MetadataFilters] = None, + current_time: Optional[datetime] = None, + ) -> List[Chunk]: + """Retrieve documents using only vector store when memory_stream is unavailable. + + Args: + query: The query string + filters: Optional metadata filters + + Returns: + List of relevant document chunks + """ + # Get documents from vector store + docs = self._index_store.similar_search_with_scores( + query, topk=self._top_k * 2, score_threshold=0, filters=filters + ) + + # Filter out documents that are marked for forgetting or merging + filtered_docs = [ + doc + for doc in docs + if ( + doc.content.find(_FORGET_PLACEHOLDER) == -1 + and doc.content.find(_MERGE_PLACEHOLDER) == -1 + ) + ] + + # Apply time weighting + rescored_docs = [] + for doc in filtered_docs: + if _METADATA_LAST_ACCESSED_AT in doc.metadata: + last_accessed_time = doc.metadata[_METADATA_LAST_ACCESSED_AT] + hours_passed = self._get_hours_passed(current_time, last_accessed_time) + time_score = (1.0 - self.decay_rate) ** hours_passed + + # Add importance score if available + importance_score = 0 + if _METADAT_IMPORTANCE in doc.metadata: + importance_score = doc.metadata[_METADAT_IMPORTANCE] + + # Combine scores + combined_score = doc.score + time_score + importance_score + rescored_docs.append((doc, combined_score)) + else: + # Just use vector similarity if no time data + rescored_docs.append((doc, doc.score)) + + # Sort by combined score + rescored_docs.sort(key=lambda x: x[1], reverse=True) + + # Return top results, updating last_accessed_at + result = [] + for doc, _ in rescored_docs[: self._k]: + doc.metadata[_METADATA_LAST_ACCESSED_AT] = current_time + 
result.append(doc) + + return result + + def _get_hours_passed(self, time: datetime, ref_time: datetime) -> float: + """Get the hours passed between two datetime objects.""" + return (time - ref_time).total_seconds() / 3600 + class LongTermMemory(Memory, Generic[T]): """Long-term memory.""" @@ -74,6 +188,7 @@ class LongTermMemory(Memory, Generic[T]): now: Optional[datetime] = None, reflection_threshold: Optional[float] = None, _default_importance: Optional[float] = None, + metadata: Optional[Dict[str, Any]] = None, ): """Create a long-term memory.""" self.now = now or datetime.now() @@ -87,6 +202,7 @@ class LongTermMemory(Memory, Generic[T]): now=self.now, index_store=vector_store ) self._default_importance = _default_importance + self._metadata: Dict[str, Any] = metadata or {"memory_type": "long_term"} @immutable def structure_clone( @@ -124,13 +240,15 @@ class LongTermMemory(Memory, Generic[T]): self.aggregate_importance += importance memory_idx = len(self.memory_retriever.memory_stream) + metadata = self._metadata + metadata[_METADAT_IMPORTANCE] = importance + metadata[_METADATA_LAST_ACCESSED_AT] = last_accessed_time + if self.session_id: + metadata[_METADATA_SESSION_ID] = self.session_id + document = Chunk( - page_content="[{}] ".format(memory_idx) - + str(memory_fragment.raw_observation), - metadata={ - _METADAT_IMPORTANCE: importance, - _METADATA_LAST_ACCESSED_AT: last_accessed_time, - }, + content="[{}] ".format(memory_idx) + str(memory_fragment.raw_observation), + metadata=metadata, ) await blocking_func_to_async( self.executor, @@ -174,10 +292,21 @@ class LongTermMemory(Memory, Generic[T]): """Fetch memories related to the observation.""" # TODO: Mock now? 
retrieved_memories = [] + filters = [] + for key, value in self._metadata.items(): + # Just handle str, int, float + if isinstance(value, (str, int, float)): + filters.append(MetadataFilter(key=key, value=value)) + if self.session_id: + filters.append( + MetadataFilter(key=_METADATA_SESSION_ID, value=self.session_id) + ) + filters = MetadataFilters(filters=filters) retrieved_list = await blocking_func_to_async( self.executor, self.memory_retriever.retrieve, observation, + filters=filters, ) for retrieved_chunk in retrieved_list: retrieved_memories.append( diff --git a/packages/dbgpt-core/src/dbgpt/agent/core/memory/short_term.py b/packages/dbgpt-core/src/dbgpt/agent/core/memory/short_term.py index a6cf85bb5..c96fe5c7f 100644 --- a/packages/dbgpt-core/src/dbgpt/agent/core/memory/short_term.py +++ b/packages/dbgpt-core/src/dbgpt/agent/core/memory/short_term.py @@ -26,7 +26,7 @@ class EnhancedShortTermMemory(ShortTermMemory[T]): self, embeddings: Embeddings, executor: Executor, - buffer_size: int = 2, + buffer_size: int = 10, enhance_similarity_threshold: float = 0.7, enhance_threshold: int = 3, ): @@ -74,29 +74,31 @@ class EnhancedShortTermMemory(ShortTermMemory[T]): self._embeddings.embed_documents, ) memory_fragment.update_embeddings(memory_fragment_embeddings) - for idx, memory_embedding in enumerate(self.short_embeddings): - similarity = await blocking_func_to_async( - self._executor, - cosine_similarity, - memory_embedding, - memory_fragment_embeddings, - ) - # Sigmoid probability, transform similarity to [0, 1] - sigmoid_prob: float = await blocking_func_to_async( - self._executor, sigmoid_function, similarity - ) - if ( - sigmoid_prob >= self.enhance_similarity_threshold - and random.random() < sigmoid_prob - ): - self.enhance_cnt[idx] += 1 - self.enhance_memories[idx].append(memory_fragment) - discard_memories = await self.transfer_to_long_term(memory_fragment) - if op == WriteOperation.ADD: - self._fragments.append(memory_fragment) - 
self.short_embeddings.append(memory_fragment_embeddings) - await self.handle_overflow(self._fragments) - return discard_memories + + async with self._lock: + for idx, memory_embedding in enumerate(self.short_embeddings): + similarity = await blocking_func_to_async( + self._executor, + cosine_similarity, + memory_embedding, + memory_fragment_embeddings, + ) + # Sigmoid probability, transform similarity to [0, 1] + sigmoid_prob: float = await blocking_func_to_async( + self._executor, sigmoid_function, similarity + ) + if ( + sigmoid_prob >= self.enhance_similarity_threshold + and random.random() < sigmoid_prob + ): + self.enhance_cnt[idx] += 1 + self.enhance_memories[idx].append(memory_fragment) + discard_memories = await self.transfer_to_long_term(memory_fragment) + if op == WriteOperation.ADD: + self._fragments.append(memory_fragment) + self.short_embeddings.append(memory_fragment_embeddings) + await self.handle_overflow(self._fragments) + return discard_memories @mutable async def transfer_to_long_term( @@ -163,11 +165,11 @@ class EnhancedShortTermMemory(ShortTermMemory[T]): Discard the least important memory fragment if the buffer size exceeds. 
""" - if len(self.short_term_memories) > self._buffer_size: + discarded_memories = [] + if len(self._fragments) > self._buffer_size: id2fragments: Dict[int, Dict] = {} - for idx in range(len(self.short_term_memories) - 1): - # Not discard the last one - memory = self.short_term_memories[idx] + for idx in range(len(self._fragments) - 1): + memory = self._fragments[idx] id2fragments[idx] = { "enhance_count": self.enhance_cnt[idx], "importance": memory.importance, @@ -180,26 +182,42 @@ class EnhancedShortTermMemory(ShortTermMemory[T]): id2fragments[x]["enhance_count"], ), ) + # Get the ID of the memory fragment to be popped pop_id = sorted_ids[0] - pop_raw_observation = self.short_term_memories[pop_id].raw_observation - self.enhance_cnt.pop(pop_id) - self.enhance_cnt.append(0) - self.enhance_memories.pop(pop_id) - self.enhance_memories.append([]) - - discard_memory = self._fragments.pop(pop_id) + pop_memory = self._fragments[pop_id] + pop_raw_observation = pop_memory.raw_observation + # Save the discarded memory + discarded_memory = self._fragments.pop(pop_id) + discarded_memories.append(discarded_memory) + # Remove the corresponding embedding vector self.short_embeddings.pop(pop_id) - # remove the discard_memory from other short-term memory's enhanced list - for idx in range(len(self.short_term_memories)): - current_enhance_memories: List[T] = self.enhance_memories[idx] - to_remove_idx = [] - for i, ehf in enumerate(current_enhance_memories): - if ehf.raw_observation == pop_raw_observation: - to_remove_idx.append(i) - for i in to_remove_idx: - current_enhance_memories.pop(i) - self.enhance_cnt[idx] -= len(to_remove_idx) + # Reorganize enhance count and enhance memories + new_enhance_memories = [[] for _ in range(self._buffer_size)] + new_enhance_cnt = [0 for _ in range(self._buffer_size)] + # Copy and adjust enhanced memory and count + current_idx = 0 + for idx in range(len(self._fragments)): + if idx == pop_id: + continue # Skip the popped memory - return 
memory_fragments, [discard_memory] - return memory_fragments, [] + # Copy the enhanced memory list but remove any items matching the + # popped memory + current_memories = [] + removed_count = 0 + + for ehf in self.enhance_memories[idx]: + if ehf.raw_observation != pop_raw_observation: + current_memories.append(ehf) + else: + removed_count += 1 + + # Update to new array + new_enhance_memories[current_idx] = current_memories + new_enhance_cnt[current_idx] = self.enhance_cnt[idx] - removed_count + current_idx += 1 + # Update enhanced memories and counts + self.enhance_memories = new_enhance_memories + self.enhance_cnt = new_enhance_cnt + + return memory_fragments, discarded_memories diff --git a/packages/dbgpt-core/src/dbgpt/agent/core/role.py b/packages/dbgpt-core/src/dbgpt/agent/core/role.py index e5d35037e..74c0fa56e 100644 --- a/packages/dbgpt-core/src/dbgpt/agent/core/role.py +++ b/packages/dbgpt-core/src/dbgpt/agent/core/role.py @@ -1,6 +1,7 @@ """Role class for role-based conversation.""" from abc import ABC +from enum import Enum from typing import Dict, List, Optional from jinja2 import Environment, Template, meta @@ -14,6 +15,15 @@ from .memory.llm import LLMImportanceScorer, LLMInsightExtractor from .profile import Profile, ProfileConfig +class AgentRunMode(str, Enum): + """Agent run mode.""" + + DEFAULT = "default" + # Run the agent in loop mode, until the conversation is over(Maximum retries or + # encounter a stop signal) + LOOP = "loop" + + class Role(ABC, BaseModel): """Role class for role-based conversation.""" @@ -216,7 +226,7 @@ class Role(ABC, BaseModel): action_output: Optional[ActionOutput] = None, check_pass: bool = True, check_fail_reason: Optional[str] = None, - ) -> None: + ) -> AgentMemoryFragment: """Write the memories to the memory. We suggest you to override this method to save the conversation to memory @@ -228,6 +238,9 @@ class Role(ABC, BaseModel): action_output(ActionOutput): The action output. 
check_pass(bool): Whether the check pass. check_fail_reason(str): The check fail reason. + + Returns: + AgentMemoryFragment: The memory fragment created. """ if not action_output: raise ValueError("Action output is required to save to memory.") @@ -245,3 +258,23 @@ class Role(ABC, BaseModel): memory_content = self._render_template(write_memory_template, **memory_map) fragment = AgentMemoryFragment(memory_content) await self.memory.write(fragment) + + action_output.memory_fragments = { + "memory": fragment.raw_observation, + "id": fragment.id, + "importance": fragment.importance, + } + return fragment + + async def recovering_memory(self, action_outputs: List[ActionOutput]) -> None: + """Recover the memory from the action outputs.""" + fragments = [] + for action_output in action_outputs: + if action_output.memory_fragments: + fragment = AgentMemoryFragment.build_from( + observation=action_output.memory_fragments["memory"], + importance=action_output.memory_fragments.get("importance"), + memory_id=action_output.memory_fragments.get("id"), + ) + fragments.append(fragment) + await self.memory.write_batch(fragments) diff --git a/packages/dbgpt-core/src/dbgpt/agent/expand/actions/react_action.py b/packages/dbgpt-core/src/dbgpt/agent/expand/actions/react_action.py index 8c6feb744..784566d54 100644 --- a/packages/dbgpt-core/src/dbgpt/agent/expand/actions/react_action.py +++ b/packages/dbgpt-core/src/dbgpt/agent/expand/actions/react_action.py @@ -2,53 +2,186 @@ import json import logging from typing import Optional -from dbgpt.agent import ResourceType -from dbgpt.agent.expand.actions.tool_action import ToolAction, ToolInput -from dbgpt.vis import Vis, VisPlugin +from dbgpt.agent import Action, ActionOutput, AgentResource, Resource, ResourceType + +from ...resource.tool.base import BaseTool, ToolParameter +from ...util.react_parser import ReActOutputParser, ReActStep +from .tool_action import ToolAction, run_tool logger = logging.getLogger(__name__) +class 
Terminate(Action[None], BaseTool): + """Terminate action. + + It is a special action to terminate the conversation, at same time, it can be a + tool to return the final answer. + """ + + async def run( + self, + ai_message: str, + resource: Optional[AgentResource] = None, + rely_action_out: Optional[ActionOutput] = None, + need_vis_render: bool = True, + **kwargs, + ) -> ActionOutput: + return ActionOutput( + is_exe_success=True, + terminate=True, + content=ai_message, + ) + + @classmethod + def get_action_description(cls) -> str: + return ( + "Terminate action representing the task is finished, or you think it is" + " impossible for you to complete the task" + ) + + @classmethod + def parse_action( + cls, + ai_message: str, + default_action: "Action", + resource: Optional[Resource] = None, + **kwargs, + ) -> Optional["Action"]: + """Parse the action from the message. + + If you want skip the action, return None. + """ + if "parser" in kwargs and isinstance(kwargs["parser"], ReActOutputParser): + parser = kwargs["parser"] + else: + parser = ReActOutputParser() + steps = parser.parse(ai_message) + if len(steps) != 1: + return None + step: ReActStep = steps[0] + if not step.action: + return None + if step.action.lower() == default_action.name.lower(): + return default_action + return None + + @property + def name(self): + return "terminate" + + @property + def description(self): + return self.get_action_description() + + @property + def args(self): + return { + "output": ToolParameter( + type="string", + name="output", + description=( + "Final answer to the task, or the reason why you think it " + "is impossible to complete the task" + ), + ), + } + + def execute(self, *args, **kwargs): + if "output" in kwargs: + return kwargs["output"] + if "final_answer" in kwargs: + return kwargs["final_answer"] + return args[0] if args else "terminate unknown" + + async def async_execute(self, *args, **kwargs): + return self.execute(*args, **kwargs) + + class 
ReActAction(ToolAction): - """ReAct action class.""" + """React action class.""" def __init__(self, **kwargs): """Tool action init.""" super().__init__(**kwargs) - self._render_protocol = VisPlugin() @property def resource_need(self) -> Optional[ResourceType]: """Return the resource type needed for the action.""" - return ResourceType.Tool + return None - @property - def render_protocol(self) -> Optional[Vis]: - """Return the render protocol.""" - return self._render_protocol + @classmethod + def parse_action( + cls, + ai_message: str, + default_action: "ReActAction", + resource: Optional[Resource] = None, + **kwargs, + ) -> Optional["ReActAction"]: + """Parse the action from the message. - @property - def out_model_type(self): - """Return the output model type.""" - return ToolInput - - @property - def ai_out_schema(self) -> Optional[str]: - """Return the AI output schema.""" - out_put_schema = { - "Thought": "Summary of thoughts to the user", - "Action": { - "tool_name": "The name of a tool that can be used to answer " - "the current" - "question or solve the current task.", - "args": { - "arg name1": "arg value1", - "arg name2": "arg value2", - }, - }, - } - - return f"""Please response in the following json format: - {json.dumps(out_put_schema, indent=2, ensure_ascii=False)} - Make sure the response is correct json and can be parsed by Python json.loads. + If you want skip the action, return None. 
""" + return default_action + + async def run( + self, + ai_message: str, + resource: Optional[AgentResource] = None, + rely_action_out: Optional[ActionOutput] = None, + need_vis_render: bool = True, + **kwargs, + ) -> ActionOutput: + """Perform the action.""" + + if "parser" in kwargs and isinstance(kwargs["parser"], ReActOutputParser): + parser = kwargs["parser"] + else: + parser = ReActOutputParser() + steps = parser.parse(ai_message) + if len(steps) != 1: + raise ValueError("Only one action is allowed each time.") + step = steps[0] + act_out = await self._do_run(ai_message, step, need_vis_render=need_vis_render) + if not act_out.action: + act_out.action = step.action + if step.thought: + act_out.thoughts = step.thought + if ( + not act_out.action_input + and step.action_input + and isinstance(step.action_input, str) + ): + act_out.action_input = step.action_input + return act_out + + async def _do_run( + self, + ai_message: str, + parsed_step: ReActStep, + need_vis_render: bool = True, + ) -> ActionOutput: + """Perform the action.""" + tool_args = {} + name = parsed_step.action + action_input = parsed_step.action_input + action_input_str = action_input + try: + if action_input and isinstance(action_input, str): + tool_args = json.loads(action_input) + elif isinstance(action_input, dict): + tool_args = action_input + action_input_str = json.dumps(action_input, ensure_ascii=False) + except json.JSONDecodeError: + if parsed_step.action == "terminate": + tool_args = {"output": action_input} + logger.warning(f"Failed to parse the args: {action_input}") + act_out = await run_tool( + name, + tool_args, + self.resource, + self.render_protocol, + need_vis_render=need_vis_render, + ) + if not act_out.action_input: + act_out.action_input = action_input_str + return act_out diff --git a/packages/dbgpt-core/src/dbgpt/agent/expand/actions/tool_action.py b/packages/dbgpt-core/src/dbgpt/agent/expand/actions/tool_action.py index 1f8dd63d4..96338a6f6 100644 --- 
a/packages/dbgpt-core/src/dbgpt/agent/expand/actions/tool_action.py +++ b/packages/dbgpt-core/src/dbgpt/agent/expand/actions/tool_action.py @@ -9,7 +9,7 @@ from dbgpt.vis.tags.vis_plugin import Vis, VisPlugin from ...core.action.base import Action, ActionOutput from ...core.schema import Status -from ...resource.base import AgentResource, ResourceType +from ...resource.base import AgentResource, Resource, ResourceType from ...resource.tool.pack import ToolPack logger = logging.getLogger(__name__) @@ -99,48 +99,69 @@ class ToolAction(Action[ToolInput]): is_exe_success=False, content="The requested correctly structured answer could not be found.", ) + return await run_tool( + param.tool_name, + param.args, + self.resource, + self.render_protocol, + need_vis_render=need_vis_render, + ) + +async def run_tool( + name: str, + args: dict, + resource: Resource, + render_protocol: Optional[Vis] = None, + need_vis_render: bool = False, +) -> ActionOutput: + """Run the tool.""" + is_terminal = None + try: + tool_packs = ToolPack.from_resource(resource) + if not tool_packs: + raise ValueError("The tool resource is not found!") + tool_pack = tool_packs[0] + response_success = True + status = Status.RUNNING.value + err_msg = None try: - tool_packs = ToolPack.from_resource(self.resource) - if not tool_packs: - raise ValueError("The tool resource is not found!") - tool_pack = tool_packs[0] - response_success = True - status = Status.RUNNING.value - err_msg = None - try: - tool_result = await tool_pack.async_execute( - resource_name=param.tool_name, **param.args - ) - status = Status.COMPLETE.value - except Exception as e: - response_success = False - logger.exception(f"Tool [{param.tool_name}] execute failed!") - status = Status.FAILED.value - err_msg = f"Tool [{param.tool_name}] execute failed! 
{str(e)}" - tool_result = err_msg - - plugin_param = { - "name": param.tool_name, - "args": param.args, - "status": status, - "logo": None, - "result": str(tool_result), - "err_msg": err_msg, - } - if not self.render_protocol: - raise NotImplementedError("The render_protocol should be implemented.") - - view = await self.render_protocol.display(content=plugin_param) - - return ActionOutput( - is_exe_success=response_success, - content=str(tool_result), - view=view, - observations=str(tool_result), - ) + tool_result = await tool_pack.async_execute(resource_name=name, **args) + status = Status.COMPLETE.value + is_terminal = tool_pack.is_terminal(name) except Exception as e: - logger.exception("Tool Action Run Failed!") - return ActionOutput( - is_exe_success=False, content=f"Tool action run failed!{str(e)}" - ) + response_success = False + logger.exception(f"Tool [{name}] execute failed!") + status = Status.FAILED.value + err_msg = f"Tool [{name}] execute failed! {str(e)}" + tool_result = err_msg + + plugin_param = { + "name": name, + "args": args, + "status": status, + "logo": None, + "result": str(tool_result), + "err_msg": err_msg, + } + if render_protocol: + view = await render_protocol.display(content=plugin_param) + elif need_vis_render: + raise NotImplementedError("The render_protocol should be implemented.") + else: + view = None + + return ActionOutput( + is_exe_success=response_success, + content=str(tool_result), + view=view, + observations=str(tool_result), + terminate=is_terminal, + ) + except Exception as e: + logger.exception("Tool Action Run Failed!") + return ActionOutput( + is_exe_success=False, + content=f"Tool action run failed!{str(e)}", + terminate=is_terminal, + ) diff --git a/packages/dbgpt-core/src/dbgpt/agent/expand/react_agent.py b/packages/dbgpt-core/src/dbgpt/agent/expand/react_agent.py index 1154b80d1..284366582 100644 --- a/packages/dbgpt-core/src/dbgpt/agent/expand/react_agent.py +++ 
b/packages/dbgpt-core/src/dbgpt/agent/expand/react_agent.py @@ -1,7 +1,7 @@ -import json import logging -from typing import Any, List, Optional, Tuple +from typing import Any, Dict, List, Optional +from dbgpt._private.pydantic import Field from dbgpt.agent import ( ActionOutput, Agent, @@ -9,90 +9,84 @@ from dbgpt.agent import ( AgentMessage, ConversableAgent, ProfileConfig, + Resource, ResourceType, ) -from dbgpt.agent.expand.actions.react_action import ReActAction -from dbgpt.core import ModelMessageRoleType +from dbgpt.agent.core.role import AgentRunMode +from dbgpt.agent.resource import BaseTool, ToolPack +from dbgpt.agent.util.react_parser import ReActOutputParser from dbgpt.util.configure import DynConfig -from dbgpt.util.json_utils import find_json_objects + +from .actions.react_action import ReActAction logger = logging.getLogger(__name__) + +_REACT_DEFAULT_GOAL = """Answer the following questions or solve the tasks by \ +selecting the right ACTION from the ACTION SPACE as best as you can. +# ACTION SPACE Simple Description # +{{ action_space_simple_desc }} +""" + _REACT_SYSTEM_TEMPLATE = """\ -You are a {{ role }}, {% if name %}named {{ name }}. -{% endif %}your goal is {% if is_retry_chat %}{{ retry_goal }} -{% else %}{{ goal }} -{% endif %}.\ -At the same time, please strictly abide by the constraints and specifications -in the "IMPORTANT REMINDER" below. -{% if resource_prompt %}\ +You are a {{ role }}, {% if name %}named {{ name }}. {% endif %}\ +{{ goal }} + +You can only use one action in the actions provided in the ACTION SPACE to solve the \ +task. For each step, you must output an Action; it cannot be empty. The maximum number \ +of steps you can take is {{ max_steps }}. +Do not output an empty string! + # ACTION SPACE # -{{ resource_prompt }} -{% endif %} -{% if expand_prompt %}\ +{{ expand_prompt }} -{% endif %}\ +{{ action_space }} +# RESPONSE FORMAT # +For each task input, your response should contain: +1. 
One analysis of the task and the current environment, reasoning to determine the \ +next action (prefix "Thought: "). +2. One action string in the ACTION SPACE (prefix "Action: "), should be one of \ +[{{ action_space_names }}]. +3. One action input (prefix "Action Input: "), empty if no input is required. -# IMPORTANT REMINDER # -The current time is:{{now_time}}. -{% if constraints %}\ -{% for constraint in constraints %}\ -{{ loop.index }}. {{ constraint }} -{% endfor %}\ -{% endif %}\ - - -{% if is_retry_chat %}\ -{% if retry_constraints %}\ -{% for retry_constraint in retry_constraints %}\ -{{ loop.index }}. {{ retry_constraint }} -{% endfor %}\ -{% endif %}\ -{% else %}\ - - - -{% endif %}\ - - - -{% if examples %}\ # EXAMPLE INTERACTION # -You can refer to the following examples: -{{ examples }}\ -{% endif %}\ +Observation: ...(This is output provided by the external environment or Action output, \ +you are not allowed to generate it.) -{% if most_recent_memories %}\ -# History of Solving Task# -{{ most_recent_memories }}\ -{% endif %}\ - -# RESPONSE FORMAT # -{% if out_schema %} {{ out_schema }} {% endif %}\ +Thought: ... +Action: ... +Action Input: ... ################### TASK ################### -Please solve the task: +Please solve this task: + +{{ question }}\ + +Please answer in the same language as the user's question. +The current time is: {{ now_time }}. 
+""" +_REACT_USER_TEMPLATE = """\ +{% if most_recent_memories %}\ +Most recent message: +{{ most_recent_memories }} +{% endif %}\ + +{% if question %}\ +Question: {{ question }} +{% endif %} """ _REACT_WRITE_MEMORY_TEMPLATE = """\ -{% if question %}Question: {{ question }} {% endif %} -{% if assistant %}Assistant: {{ assistant }} {% endif %} +{% if thought %}Thought: {{ thought }} {% endif %} +{% if action %}Action: {{ action }} {% endif %} +{% if action_input %}Action Input: {{ action_input }} {% endif %} {% if observation %}Observation: {{ observation }} {% endif %} """ class ReActAgent(ConversableAgent): - end_action_name: str = DynConfig( - "terminate", - category="agent", - key="dbgpt_agent_expand_plugin_assistant_agent_end_action_name", - ) - max_steps: int = DynConfig( - 10, - category="agent", - key="dbgpt_agent_expand_plugin_assistant_agent_max_steps", - ) + max_retry_count: int = 15 + run_mode: AgentRunMode = AgentRunMode.LOOP + profile: ProfileConfig = ProfileConfig( name=DynConfig( "ReAct", @@ -100,198 +94,127 @@ class ReActAgent(ConversableAgent): key="dbgpt_agent_expand_plugin_assistant_agent_name", ), role=DynConfig( - "ToolMaster", + "ReActToolMaster", category="agent", key="dbgpt_agent_expand_plugin_assistant_agent_role", ), goal=DynConfig( - "Read and understand the tool information given in the action space " - "below to understand their capabilities and how to use them,and choosing " - "the right tool to solve the task", + _REACT_DEFAULT_GOAL, category="agent", key="dbgpt_agent_expand_plugin_assistant_agent_goal", ), - constraints=DynConfig( - [ - "Achieve the goal step by step." - "Each step, please read the parameter definition of the tool carefully " - "and extract the specific parameters required to execute the tool " - "from the user's goal.", - "information in json format according to the following required format." 
- "If there is an example, please refer to the sample format output.", - "each step, you can only select one tool in action space.", - ], - category="agent", - key="dbgpt_agent_expand_plugin_assistant_agent_constraints", - ), system_prompt_template=_REACT_SYSTEM_TEMPLATE, + user_prompt_template=_REACT_USER_TEMPLATE, write_memory_template=_REACT_WRITE_MEMORY_TEMPLATE, ) + parser: ReActOutputParser = Field(default_factory=ReActOutputParser) def __init__(self, **kwargs): + """Init indicator AssistantAgent.""" super().__init__(**kwargs) + self._init_actions([ReActAction]) - async def review(self, message: Optional[str], censored: Agent) -> Tuple[bool, Any]: - """Review the message based on the censored message.""" - try: - json_obj = find_json_objects(message) - if len(json_obj) == 0: - raise ValueError( - "No correct json object found in the message。" - "Please strictly output JSON in the defined " - "format, and only one action can be ouput each time. " - ) - return True, json_obj[0] - except Exception as e: - logger.error(f"review error: {e}") - raise e - - def validate_action(self, action_name: str) -> bool: - tools = self.resource.get_resource_by_type(ResourceType.Tool) - for tool in tools: - if tool.name == action_name: - return True - raise ValueError(f"{action_name} is not in the action space.") - - async def generate_reply( + async def _a_init_reply_message( self, received_message: AgentMessage, + rely_messages: Optional[List[AgentMessage]] = None, + ) -> AgentMessage: + reply_message = super()._init_reply_message(received_message, rely_messages) + + tool_packs = ToolPack.from_resource(self.resource) + action_space = [] + action_space_names = [] + action_space_simple_desc = [] + if tool_packs: + tool_pack = tool_packs[0] + for tool in tool_pack.sub_resources: + tool_desc, _ = await tool.get_prompt(lang=self.language) + action_space_names.append(tool.name) + action_space.append(tool_desc) + if isinstance(tool, BaseTool): + tool_simple_desc = tool.description + 
else: + tool_simple_desc = tool.get_prompt() + action_space_simple_desc.append(f"{tool.name}: {tool_simple_desc}") + else: + for action in self.actions: + action_space_names.append(action.name) + action_space.append(action.get_action_description()) + # self.actions + reply_message.context = { + "max_steps": self.max_retry_count, + "action_space": "\n".join(action_space), + "action_space_names": ", ".join(action_space_names), + "action_space_simple_desc": "\n".join(action_space_simple_desc), + } + return reply_message + + async def load_resource(self, question: str, is_retry_chat: bool = False): + """Load agent bind resource.""" + if self.resource: + + def _remove_tool(r: Resource): + if r.type() == ResourceType.Tool: + return None + return r + + # Remove all tools from the resource + # We will handle tools separately + new_resource = self.resource.apply(apply_func=_remove_tool) + if new_resource: + resource_prompt, resource_reference = await new_resource.get_prompt( + lang=self.language, question=question + ) + return resource_prompt, resource_reference + return None, None + + def prepare_act_param( + self, + received_message: Optional[AgentMessage], + sender: Agent, + rely_messages: Optional[List[AgentMessage]] = None, + **kwargs, + ) -> Dict[str, Any]: + """Prepare the parameters for the act method.""" + return { + "parser": self.parser, + } + + async def act( + self, + message: AgentMessage, sender: Agent, reviewer: Optional[Agent] = None, - rely_messages: Optional[List[AgentMessage]] = None, - historical_dialogues: Optional[List[AgentMessage]] = None, is_retry_chat: bool = False, last_speaker_name: Optional[str] = None, **kwargs, - ) -> AgentMessage: - """Generate a reply based on the received messages.""" + ) -> ActionOutput: + """Perform actions.""" + message_content = message.content + if not message_content: + raise ValueError("The response is empty.") try: - logger.info( - f"generate agent reply!sender={sender}, " - f"rely_messages_len={rely_messages}" - ) 
- self.validate_action(self.end_action_name) - observation = AgentMessage(content="please start!") - reply_message: AgentMessage = self._init_reply_message( - received_message=received_message - ) - thinking_messages, resource_info = await self._load_thinking_messages( - received_message=observation, - sender=sender, - rely_messages=rely_messages, - historical_dialogues=historical_dialogues, - context=reply_message.get_dict_context(), - is_retry_chat=is_retry_chat, - ) - # attach current task to system prompt - thinking_messages[0].content = ( - thinking_messages[0].content + "\n" + received_message.content - ) - done = False - max_steps = self.max_steps - await self.write_memories( - question=received_message.content, - ai_message="", - ) - while not done and max_steps > 0: - ai_message = "" - try: - # 1. thinking - llm_reply, model_name = await self.thinking( - thinking_messages, sender - ) - reply_message.model_name = model_name - reply_message.resource_info = resource_info - ai_message = llm_reply - thinking_messages.append( - AgentMessage(role=ModelMessageRoleType.AI, content=llm_reply) - ) - approve, json_obj = await self.review(llm_reply, self) - logger.info(f"jons_obj: {json_obj}") - action = json_obj["Action"] - thought = json_obj["Thought"] - action.update({"thought": thought}) - reply_message.content = json.dumps(action, ensure_ascii=False) - tool_name = action["tool_name"] - self.validate_action(tool_name) - # 2. act - act_extent_param = self.prepare_act_param( - received_message=received_message, - sender=sender, - rely_messages=rely_messages, - historical_dialogues=historical_dialogues, - ) - act_out: ActionOutput = await self.act( - message=reply_message, - sender=sender, - reviewer=reviewer, - is_retry_chat=is_retry_chat, - last_speaker_name=last_speaker_name, - **act_extent_param, - ) - if act_out: - reply_message.action_report = act_out - - # 3. 
obs - check_pass, reason = await self.verify( - reply_message, sender, reviewer - ) - done = tool_name == self.end_action_name and check_pass - if check_pass: - logger.info(f"Observation:{act_out.content}") - thinking_messages.append( - AgentMessage( - role=ModelMessageRoleType.HUMAN, - content=f"Observation: {tool_name} " - f"output:{act_out.content}\n", - ) - ) - await self.write_memories( - question="", - ai_message=ai_message, - action_output=act_out, - check_pass=check_pass, - ) - else: - observation = f"Observation: {reason}" - logger.info(f"Observation:{observation}") - thinking_messages.append( - AgentMessage( - role=ModelMessageRoleType.HUMAN, content=observation - ) - ) - await self.write_memories( - question="", - ai_message=ai_message, - check_pass=check_pass, - check_fail_reason=reason, - ) - max_steps -= 1 - except Exception as e: - fail_reason = ( - f"Observation: Exception occurs:({type(e).__name__}){e}." - ) - logger.error(fail_reason) - thinking_messages.append( - AgentMessage( - role=ModelMessageRoleType.HUMAN, content=fail_reason - ) - ) - await self.write_memories( - question="", - ai_message=ai_message, - check_pass=False, - check_fail_reason=fail_reason, - ) - reply_message.success = done - await self.adjust_final_message(True, reply_message) - return reply_message + steps = self.parser.parse(message_content) + err_msg = None + if not steps: + err_msg = "No correct response found." + elif len(steps) != 1: + err_msg = "Only one action is allowed each time." 
+ if err_msg: + return ActionOutput(is_exe_success=False, content=err_msg) except Exception as e: - logger.exception("Generate reply exception!") - err_message = AgentMessage(content=str(e)) - err_message.success = False - return err_message + logger.warning(f"review error: {e}") + + action_output = await super().act( + message=message, + sender=sender, + reviewer=reviewer, + is_retry_chat=is_retry_chat, + last_speaker_name=last_speaker_name, + **kwargs, + ) + return action_output async def write_memories( self, @@ -300,7 +223,7 @@ class ReActAgent(ConversableAgent): action_output: Optional[ActionOutput] = None, check_pass: bool = True, check_fail_reason: Optional[str] = None, - ) -> None: + ) -> AgentMemoryFragment: """Write the memories to the memory. We suggest you to override this method to save the conversation to memory @@ -312,18 +235,33 @@ class ReActAgent(ConversableAgent): action_output(ActionOutput): The action output. check_pass(bool): Whether the check pass. check_fail_reason(str): The check fail reason. + + Returns: + AgentMemoryFragment: The memory fragment created. 
""" - observation = "" - if action_output and action_output.observations: - observation = action_output.observations - elif check_fail_reason: - observation = check_fail_reason + if not action_output: + raise ValueError("Action output is required to save to memory.") + + mem_thoughts = action_output.thoughts or ai_message + action = action_output.action + action_input = action_output.action_input + observation = check_fail_reason or action_output.observations + memory_map = { - "question": question, - "assistant": ai_message, + "thought": mem_thoughts, + "action": action, "observation": observation, } + if action_input: + memory_map["action_input"] = action_input + write_memory_template = self.write_memory_template memory_content = self._render_template(write_memory_template, **memory_map) fragment = AgentMemoryFragment(memory_content) await self.memory.write(fragment) + action_output.memory_fragments = { + "memory": fragment.raw_observation, + "id": fragment.id, + "importance": fragment.importance, + } + return fragment diff --git a/packages/dbgpt-core/src/dbgpt/agent/resource/base.py b/packages/dbgpt-core/src/dbgpt/agent/resource/base.py index 53c153fc0..d28b7edfd 100644 --- a/packages/dbgpt-core/src/dbgpt/agent/resource/base.py +++ b/packages/dbgpt-core/src/dbgpt/agent/resource/base.py @@ -4,7 +4,19 @@ import dataclasses import json from abc import ABC, abstractmethod from enum import Enum -from typing import Any, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union, cast +from typing import ( + Any, + Callable, + Dict, + Generic, + List, + Optional, + Tuple, + Type, + TypeVar, + Union, + cast, +) from pydantic import field_validator @@ -228,6 +240,13 @@ class Resource(ABC, Generic[P]): resources.append(resource) return resources + def apply( + self, + apply_func: Callable[["Resource"], Union["Resource", List["Resource"], None]], + ) -> Union["Resource", None]: + """Apply the function to the resource.""" + return self + class AgentResource(BaseModel): 
"""Agent resource class.""" diff --git a/packages/dbgpt-core/src/dbgpt/agent/resource/pack.py b/packages/dbgpt-core/src/dbgpt/agent/resource/pack.py index 7e947cb5e..de411b263 100644 --- a/packages/dbgpt-core/src/dbgpt/agent/resource/pack.py +++ b/packages/dbgpt-core/src/dbgpt/agent/resource/pack.py @@ -4,8 +4,9 @@ Resource pack is a collection of resources(also, it is a resource) that can be e together. """ +import copy import dataclasses -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast from .base import Resource, ResourceParameters, ResourceType @@ -117,3 +118,38 @@ class ResourcePack(Resource[PackResourceParameters]): def sub_resources(self) -> List[Resource]: """Return the resources.""" return list(self._resources.values()) + + def apply( + self, apply_func: Callable[[Resource], Union[Resource, List[Resource], None]] + ) -> Union[Resource, None]: + """Apply the function to the resource.""" + if not self.is_pack: + return self + + def _apply_func_to_resource( + resource: Resource, + ) -> Union[Resource, List[Resource], None]: + if resource.is_pack: + resources = [] + resource_copy = cast(ResourcePack, copy.copy(resource)) + for resource_copy in resource_copy.sub_resources: + result = _apply_func_to_resource(resource_copy) + if result: + if isinstance(result, list): + resources.extend(result) + else: + resources.append(result) + # Replace the resources + resource_copy._resources = { + resource.name: resource for resource in resources + } + else: + return apply_func(resource) + + new_resource = _apply_func_to_resource(self) + resource_copy = cast(ResourcePack, copy.copy(self)) + if isinstance(new_resource, list): + resource_copy._resources = { + resource.name: resource for resource in new_resource + } + return new_resource diff --git a/packages/dbgpt-core/src/dbgpt/agent/resource/tool/pack.py b/packages/dbgpt-core/src/dbgpt/agent/resource/tool/pack.py index 0e1665e1c..b53f4e6c7 100644 
--- a/packages/dbgpt-core/src/dbgpt/agent/resource/tool/pack.py +++ b/packages/dbgpt-core/src/dbgpt/agent/resource/tool/pack.py @@ -1,7 +1,7 @@ """Tool resource pack module.""" import os -from typing import Any, Callable, Dict, List, Optional, Type, Union, cast +from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union, cast from mcp import ClientSession from mcp.client.sse import sse_client @@ -24,15 +24,22 @@ def _is_function_tool(resources: Any) -> bool: ) +def _is_tool(resources: Any) -> bool: + return isinstance(resources, BaseTool) or _is_function_tool(resources) + + def _to_tool_list(resources: ToolResourceType) -> List[BaseTool]: if isinstance(resources, BaseTool): return [resources] - elif isinstance(resources, list) and all( - isinstance(r, BaseTool) for r in resources - ): - return cast(List[BaseTool], resources) - elif isinstance(resources, list) and all(_is_function_tool(r) for r in resources): - return [cast(FunctionTool, getattr(r, "_tool")) for r in resources] + elif isinstance(resources, Sequence) and all(_is_tool(r) for r in resources): + new_resources = [] + for r in resources: + if isinstance(r, BaseTool): + new_resources.append(r) + else: + function_tool = cast(FunctionTool, getattr(r, "_tool")) + new_resources.append(function_tool) + return new_resources elif _is_function_tool(resources): function_tool = cast(FunctionTool, getattr(resources, "_tool")) return [function_tool] @@ -191,6 +198,15 @@ class ToolPack(ResourcePack): except Exception as e: raise ToolExecutionException(f"Execution error: {str(e)}") + def is_terminal(self, resource_name: Optional[str] = None) -> bool: + """Check if the tool is terminal.""" + from ...expand.actions.react_action import Terminate + + if not resource_name: + return False + tl = self._get_execution_tool(resource_name) + return isinstance(tl, Terminate) + class AutoGPTPluginToolPack(ToolPack): """Auto-GPT plugin tool pack class.""" diff --git 
a/packages/dbgpt-core/src/dbgpt/agent/util/conv_utils.py b/packages/dbgpt-core/src/dbgpt/agent/util/conv_utils.py new file mode 100644 index 000000000..6f7abbd26 --- /dev/null +++ b/packages/dbgpt-core/src/dbgpt/agent/util/conv_utils.py @@ -0,0 +1,17 @@ +import re +from typing import Optional, Tuple + + +def parse_conv_id(conv_id: str) -> Tuple[str, Optional[int]]: + pattern = r"([\w-]+)_(\d+)" + match = re.match(pattern, conv_id) + if match: + # TODO: conv_id passed from serve module will be like "real_conv_id_1" now, + # so we need to extract + # the real conv_id + real_conv_id = match.group(1) + # Extract the number part + number_part = match.group(2) + return real_conv_id, number_part + else: + return conv_id, None diff --git a/packages/dbgpt-core/src/dbgpt/agent/util/react_parser.py b/packages/dbgpt-core/src/dbgpt/agent/util/react_parser.py new file mode 100644 index 000000000..b0b6ff182 --- /dev/null +++ b/packages/dbgpt-core/src/dbgpt/agent/util/react_parser.py @@ -0,0 +1,212 @@ +import json +import re +from dataclasses import dataclass +from typing import Any, List, Optional + + +@dataclass +class ReActStep: + """ + Dataclass representing a single step in the ReAct pattern. + """ + + thought: Optional[str] = None + action: Optional[str] = None + action_input: Optional[Any] = None + observation: Optional[Any] = None + is_terminal: bool = False + + +class ReActOutputParser: + """ + Parser for ReAct format model outputs with configurable prefixes. + + This parser extracts structured information from language model outputs + that follow the ReAct pattern: Thought -> Action -> Action Input -> Observation. + """ + + def __init__( + self, + thought_prefix: str = "Thought:", + action_prefix: str = "Action:", + action_input_prefix: str = "Action Input:", + observation_prefix: str = "Observation:", + terminate_action: str = "terminate", + ): + """ + Initialize the ReAct output parser with configurable prefixes. 
+ + Args: + thought_prefix: Prefix string that indicates the start of a thought. + action_prefix: Prefix string that indicates the start of an action. + action_input_prefix: Prefix string that indicates the start of action input. + observation_prefix: Prefix string that indicates the start of an + observation. + terminate_action: String that indicates termination action. + """ + self.thought_prefix = thought_prefix + self.action_prefix = action_prefix + self.action_input_prefix = action_input_prefix + self.observation_prefix = observation_prefix + self.terminate_action = terminate_action + + # Escape special regex characters in prefixes + self.thought_prefix_escaped = re.escape(thought_prefix) + self.action_prefix_escaped = re.escape(action_prefix) + self.action_input_prefix_escaped = re.escape(action_input_prefix) + self.observation_prefix_escaped = re.escape(observation_prefix) + + def parse(self, text: str) -> List[ReActStep]: + """ + Parse the ReAct format output text into structured steps. + + Args: + text: The text to parse, containing ReAct formatted content. + + Returns: + List of ReActStep dataclasses, each containing thought, action, + action_input, and observation. 
+ """ + # Split the text into steps based on thought prefix + steps = [] + + # Remove any leading/trailing whitespace + text = text.strip() + + # Find all instances of the thought prefix + thought_matches = list(re.finditer(rf"{self.thought_prefix_escaped}\s*", text)) + + if not thought_matches: + return [] + + # Process each thought section + for i, match in enumerate(thought_matches): + start_pos = match.start() + + # Determine end position (either next thought or end of text) + if i < len(thought_matches) - 1: + end_pos = thought_matches[i + 1].start() + else: + end_pos = len(text) + + # Extract the current step's text + step_text = text[start_pos:end_pos].strip() + + # Parse the step + step_data = self._parse_step(step_text) + if step_data: + steps.append(step_data) + + return steps + + def _parse_step(self, step_text: str) -> Optional[ReActStep]: + """ + Parse a single step of the ReAct format. + + Args: + step_text: Text containing a single thought-action-input-observation + sequence. + + Returns: + ReActStep dataclass with thought, action, action_input, and observation, + or None if parsing fails. 
+ """ + # Initialize the result + thought = None + action = None + action_input = None + observation = None + is_terminal = False + + # Extract thought + thought_match = re.search( + rf"{self.thought_prefix_escaped}\s*(.*?)(?={self.action_prefix_escaped}|{self.observation_prefix_escaped}|$)", + step_text, + re.DOTALL, + ) + if thought_match: + thought = thought_match.group(1).strip() + + # Extract action + action_match = re.search( + rf"{self.action_prefix_escaped}\s*(.*?)(?={self.action_input_prefix_escaped}|{self.observation_prefix_escaped}|$)", + step_text, + re.DOTALL, + ) + if action_match: + action = action_match.group(1).strip() + + # Check if this is a terminate action + is_terminal = action.lower() == self.terminate_action.lower() + + # Extract action input + action_input_match = re.search( + rf"{self.action_input_prefix_escaped}\s*(.*?)(?={self.observation_prefix_escaped}|{self.thought_prefix_escaped}|$)", + step_text, + re.DOTALL, + ) + if action_input_match: + action_input_text = action_input_match.group(1).strip() + + # Try to parse action input as JSON if it looks like JSON + if ( + action_input_text.startswith("{") and action_input_text.endswith("}") + ) or ( + action_input_text.startswith("[") and action_input_text.endswith("]") + ): + try: + action_input = json.loads(action_input_text) + except json.JSONDecodeError: + action_input = action_input_text + else: + action_input = action_input_text + + # Extract observation + observation_match = re.search( + rf"{self.observation_prefix_escaped}\s*(.*?)(?={self.thought_prefix_escaped}|$)", + step_text, + re.DOTALL, + ) + if observation_match: + observation_text = observation_match.group(1).strip() + + # Try to parse observation as JSON if it looks like JSON + if ( + observation_text.startswith("{") and observation_text.endswith("}") + ) or (observation_text.startswith("[") and observation_text.endswith("]")): + try: + observation = json.loads(observation_text) + except json.JSONDecodeError: + observation 
= observation_text + else: + observation = observation_text + + # Only return if we have at least thought or action + if thought or action: + return ReActStep( + thought=thought, + action=action, + action_input=action_input, + observation=observation, + is_terminal=is_terminal, + ) + return None + + def get_final_output(self, steps: List[ReActStep]) -> Optional[str]: + """ + Get the final output from a terminate action if it exists. + + Args: + steps: List of parsed steps. + + Returns: + The final output string or None if no terminate action is found. + """ + for step in reversed(steps): # Look from the end + if step.is_terminal and step.action == self.terminate_action: + if ( + isinstance(step.action_input, dict) + and "output" in step.action_input + ): + return step.action_input["output"] + return None diff --git a/packages/dbgpt-core/src/dbgpt/agent/util/tests/test_react_parser.py b/packages/dbgpt-core/src/dbgpt/agent/util/tests/test_react_parser.py new file mode 100644 index 000000000..718e2ea51 --- /dev/null +++ b/packages/dbgpt-core/src/dbgpt/agent/util/tests/test_react_parser.py @@ -0,0 +1,404 @@ +""" +Unit tests for the ReActOutputParser using pytest. +""" + +from ..react_parser import ReActOutputParser + + +class TestReActOutputParser: + """Test suite for the ReActOutputParser using pytest.""" + + def test_basic_parsing(self): + """Test basic parsing of a simple ReAct output.""" + parser = ReActOutputParser() + text = """Thought: I should calculate 2+2. +Action: calculator +Action Input: {"operation": "add", "a": 2, "b": 2}""" + + steps = parser.parse(text) + + assert len(steps) == 1 + assert steps[0].thought == "I should calculate 2+2." 
+ assert steps[0].action == "calculator" + assert steps[0].action_input == {"operation": "add", "a": 2, "b": 2} + assert steps[0].observation is None + assert steps[0].is_terminal is False + + def test_parsing_with_observation(self): + """Test parsing with observation included.""" + parser = ReActOutputParser() + text = """Thought: I should calculate 2+2. +Action: calculator +Action Input: {"operation": "add", "a": 2, "b": 2} +Observation: 4""" + + steps = parser.parse(text) + + assert len(steps) == 1 + assert steps[0].thought == "I should calculate 2+2." + assert steps[0].action == "calculator" + assert steps[0].action_input == {"operation": "add", "a": 2, "b": 2} + assert steps[0].observation == "4" + assert steps[0].is_terminal is False + + def test_terminal_action(self): + """Test parsing of a terminal action.""" + parser = ReActOutputParser() + text = """Thought: I've finished the calculation. +Action: terminate +Action Input: {"output": "The answer is 4"}""" + + steps = parser.parse(text) + + assert len(steps) == 1 + assert steps[0].thought == "I've finished the calculation." + assert steps[0].action == "terminate" + assert steps[0].action_input == {"output": "The answer is 4"} + assert steps[0].is_terminal is True + + # Test get_final_output + final_output = parser.get_final_output(steps) + assert final_output == "The answer is 4" + + def test_multi_step_parsing(self): + """Test parsing of multiple steps.""" + parser = ReActOutputParser() + text = """Thought: I need to calculate 10 * 5. +Action: calculator +Action Input: {"operation": "multiply", "a": 10, "b": 5} +Observation: 50 + +Thought: Now I need to add 20 to the result. +Action: calculator +Action Input: {"operation": "add", "a": 50, "b": 20} +Observation: 70 + +Thought: The calculation is complete. 
+Action: terminate +Action Input: {"output": "10 * 5 + 20 = 70"}""" + + steps = parser.parse(text) + + assert len(steps) == 3 + + # Check first step + assert steps[0].thought == "I need to calculate 10 * 5." + assert steps[0].action == "calculator" + assert steps[0].action_input == {"operation": "multiply", "a": 10, "b": 5} + assert steps[0].observation == "50" + + # Check second step + assert steps[1].thought == "Now I need to add 20 to the result." + assert steps[1].action == "calculator" + assert steps[1].action_input == {"operation": "add", "a": 50, "b": 20} + assert steps[1].observation == "70" + + # Check third step + assert steps[2].thought == "The calculation is complete." + assert steps[2].action == "terminate" + assert steps[2].action_input == {"output": "10 * 5 + 20 = 70"} + assert steps[2].is_terminal is True + + # Test get_final_output + final_output = parser.get_final_output(steps) + assert final_output == "10 * 5 + 20 = 70" + + def test_custom_prefixes(self): + """Test parsing with custom prefixes.""" + parser = ReActOutputParser( + thought_prefix="Think:", + action_prefix="Do:", + action_input_prefix="With:", + observation_prefix="Result:", + terminate_action="finish", + ) + + text = """Think: I should calculate 5 + 10. +Do: calculate +With: {"x": 5, "y": 10} +Result: 15 + +Think: Now I'm done. +Do: finish +With: {"output": "The sum is 15"}""" + + steps = parser.parse(text) + + assert len(steps) == 2 + + # Check first step + assert steps[0].thought == "I should calculate 5 + 10." + assert steps[0].action == "calculate" + assert steps[0].action_input == {"x": 5, "y": 10} + assert steps[0].observation == "15" + assert steps[0].is_terminal is False + + # Check second step + assert steps[1].thought == "Now I'm done." 
+ assert steps[1].action == "finish" + assert steps[1].action_input == {"output": "The sum is 15"} + assert steps[1].is_terminal is True + + # Test get_final_output + final_output = parser.get_final_output(steps) + assert final_output == "The sum is 15" + + def test_non_json_action_input(self): + """Test parsing of non-JSON action inputs.""" + parser = ReActOutputParser() + text = """Thought: I'll search for information. +Action: search +Action Input: python programming language""" + + steps = parser.parse(text) + + assert len(steps) == 1 + assert steps[0].thought == "I'll search for information." + assert steps[0].action == "search" + assert steps[0].action_input == "python programming language" + + def test_non_json_observation(self): + """Test parsing of non-JSON observations.""" + parser = ReActOutputParser() + text = """Thought: I'll search for information. +Action: search +Action Input: {"query": "python programming language"} +Observation: Python is a high-level, general-purpose programming language.""" + + steps = parser.parse(text) + + assert len(steps) == 1 + assert steps[0].thought == "I'll search for information." + assert steps[0].action == "search" + assert steps[0].action_input == {"query": "python programming language"} + assert ( + steps[0].observation + == "Python is a high-level, general-purpose programming language." + ) + + def test_missing_components(self): + """Test parsing when some components are missing.""" + parser = ReActOutputParser() + + # Missing action input + text1 = """Thought: I'll search for information. +Action: search""" + + steps = parser.parse(text1) + assert len(steps) == 1 + assert steps[0].thought == "I'll search for information." + assert steps[0].action == "search" + assert steps[0].action_input is None + + # Only thought + text2 = """Thought: I'm thinking about what to do next.""" + + steps = parser.parse(text2) + assert len(steps) == 1 + assert steps[0].thought == "I'm thinking about what to do next." 
+ assert steps[0].action is None + + def test_invalid_json_handling(self): + """Test parsing when JSON is invalid.""" + parser = ReActOutputParser() + text = """Thought: I'll do a calculation. +Action: calculator +Action Input: {"operation": "add", "a": 2, "b": 3,}""" # Invalid JSON (extra comma) + + steps = parser.parse(text) + + assert len(steps) == 1 + assert steps[0].thought == "I'll do a calculation." + assert steps[0].action == "calculator" + # Should keep as string when JSON parsing fails + assert steps[0].action_input == """{"operation": "add", "a": 2, "b": 3,}""" + + def test_multiple_json_blocks(self): + """Test parsing when multiple JSON objects are in the text.""" + parser = ReActOutputParser() + text = """Thought: I need to check multiple values. +Action: check_values +Action Input: [{"name": "first", "value": 100}, {"name": "second", "value": 200}]""" + + steps = parser.parse(text) + + assert len(steps) == 1 + assert steps[0].thought == "I need to check multiple values." + assert steps[0].action == "check_values" + assert isinstance(steps[0].action_input, list) + assert len(steps[0].action_input) == 2 + assert steps[0].action_input[0]["name"] == "first" + assert steps[0].action_input[1]["value"] == 200 + + def test_empty_input(self): + """Test parsing with empty input.""" + parser = ReActOutputParser() + text = "" + + steps = parser.parse(text) + assert len(steps) == 0 + + def test_no_thought_prefix(self): + """Test parsing when there's no thought prefix.""" + parser = ReActOutputParser() + text = """This is some text without any prefixes""" + + steps = parser.parse(text) + assert len(steps) == 0 + + def test_multiline_content(self): + """Test parsing when content spans multiple lines.""" + parser = ReActOutputParser() + text = """Thought: I need to analyze this data. +The data seems to have multiple entries. +I should process each one. 
+Action: process_data +Action Input: { + "entries": [ + {"id": 1, "value": "first"}, + {"id": 2, "value": "second"} + ], + "options": { + "sort": true, + "filter": false + } +}""" + + steps = parser.parse(text) + + assert len(steps) == 1 + assert steps[0].thought.startswith("I need to analyze this data.") + assert "multiple entries" in steps[0].thought + assert steps[0].action == "process_data" + assert isinstance(steps[0].action_input, dict) + assert len(steps[0].action_input["entries"]) == 2 + assert steps[0].action_input["options"]["sort"] is True + + def test_whitespace_handling(self): + """Test parsing with various whitespace patterns.""" + parser = ReActOutputParser() + text = """ Thought: I need to calculate something. + Action: calculator + Action Input: {"a": 1, "b": 2} + + Observation: 3 """ + + steps = parser.parse(text) + + assert len(steps) == 1 + assert steps[0].thought == "I need to calculate something." + assert steps[0].action == "calculator" + assert steps[0].action_input == {"a": 1, "b": 2} + assert steps[0].observation == "3" + + def test_get_final_output_without_terminate(self): + """Test get_final_output when there's no terminate action.""" + parser = ReActOutputParser() + text = """Thought: I need to check the weather. +Action: weather_api +Action Input: {"location": "New York"} +Observation: Sunny, 75°F""" + + steps = parser.parse(text) + + final_output = parser.get_final_output(steps) + assert final_output is None + + def test_custom_terminate_action(self): + """Test with a custom terminate action.""" + parser = ReActOutputParser(terminate_action="end_task") + text = """Thought: I'm finished with the calculation. 
+Action: end_task +Action Input: {"output": "The result is 42"}""" + + steps = parser.parse(text) + + assert len(steps) == 1 + assert steps[0].is_terminal is True + + final_output = parser.get_final_output(steps) + assert final_output == "The result is 42" + + def test_example_in_prompt(self): + """Test the specific example provided in the prompt.""" + parser = ReActOutputParser() + text = """Thought: First, I need to calculate the product of 10 and 99 using \ +the simple_calculator tool. +Action: simple_calculator +Action Input: {"first_number": 10, "second_number": 99, "operator": "*"} +Observation: 990 +Thought: Now that I have the product, I need to count the number of files in the /tmp directory. +Action: count_directory_files +Action Input: {"path": "/tmp"} +Observation: 42 +Thought: I have successfully calculated the product and counted the files in /tmp. The task is complete. +Action: terminate +Action Input: {"output": "The product of 10 and 99 is 990, and there are 42 files in /tmp."}""" # noqa + + steps = parser.parse(text) + + assert len(steps) == 3 + + # First step + assert ( + steps[0].thought + == "First, I need to calculate the product of 10 and 99 using the simple_calculator tool." # noqa + ) + assert steps[0].action == "simple_calculator" + assert steps[0].action_input == { + "first_number": 10, + "second_number": 99, + "operator": "*", + } + assert steps[0].observation == "990" + + # Second step + assert ( + steps[1].thought + == "Now that I have the product, I need to count the number of files in the /tmp directory." # noqa + ) + assert steps[1].action == "count_directory_files" + assert steps[1].action_input == {"path": "/tmp"} + assert steps[1].observation == "42" + + # Third step + assert ( + steps[2].thought + == "I have successfully calculated the product and counted the files in /tmp. The task is complete." 
# noqa + ) + assert steps[2].action == "terminate" + assert steps[2].action_input == { + "output": "The product of 10 and 99 is 990, and there are 42 files in /tmp." + } + assert steps[2].is_terminal is True + + # Final output + final_output = parser.get_final_output(steps) + assert ( + final_output + == "The product of 10 and 99 is 990, and there are 42 files in /tmp." + ) + + def test_file_create(self): + """Test parsing with file creation.""" + parser = ReActOutputParser() + text = """Thought: I need to create a new file. +Action: CreateFile +Action Input: CreateFile(filepath="hello_world.py"): +``` +print("Hello, world!") +```""" + steps = parser.parse(text) + + assert len(steps) == 1 + assert steps[0].thought == "I need to create a new file." + assert steps[0].action == "CreateFile" + assert ( + steps[0].action_input + == """CreateFile(filepath="hello_world.py"): +``` +print("Hello, world!") +```""" + ) + assert steps[0].observation is None + assert steps[0].is_terminal is False diff --git a/packages/dbgpt-core/src/dbgpt/rag/retriever/time_weighted.py b/packages/dbgpt-core/src/dbgpt/rag/retriever/time_weighted.py index 4aa331743..cf150851c 100644 --- a/packages/dbgpt-core/src/dbgpt/rag/retriever/time_weighted.py +++ b/packages/dbgpt-core/src/dbgpt/rag/retriever/time_weighted.py @@ -1,8 +1,9 @@ -"""Time weighted retriever.""" +"""Time weighted retriever with external storage support.""" import datetime +import logging from copy import deepcopy -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Protocol, Tuple from dbgpt.core import Chunk from dbgpt.rag.retriever.rerank import Ranker @@ -12,6 +13,31 @@ from dbgpt.storage.vector_store.filters import MetadataFilters from .embedding import EmbeddingRetriever +logger = logging.getLogger(__name__) + + +class DocumentStorage(Protocol): + """Protocol for external document storage.""" + + def get_all_documents(self) -> List[Chunk]: + """Get all documents from storage. 
+ + Returns: + List of document chunks + """ + ... + + def save_documents(self, documents: List[Chunk]) -> bool: + """Save documents to storage. + + Args: + documents: List of document chunks to save + + Returns: + Boolean indicating success + """ + ... + def _get_hours_passed(time: datetime.datetime, ref_time: datetime.datetime) -> float: """Get the hours passed between two datetime objects.""" @@ -19,7 +45,7 @@ def _get_hours_passed(time: datetime.datetime, ref_time: datetime.datetime) -> f class TimeWeightedEmbeddingRetriever(EmbeddingRetriever): - """Time weighted embedding retriever.""" + """Time weighted embedding retriever with external storage support.""" def __init__( self, @@ -28,14 +54,18 @@ class TimeWeightedEmbeddingRetriever(EmbeddingRetriever): query_rewrite: Optional[QueryRewrite] = None, rerank: Optional[Ranker] = None, decay_rate: float = 0.01, + external_storage: Optional[DocumentStorage] = None, ): """Initialize TimeWeightedEmbeddingRetriever. Args: index_store (IndexStoreBase): vector store connector - top_k (int): top k - query_rewrite (Optional[QueryRewrite]): query rewrite - rerank (Ranker): rerank + top_k (int): top k results to retrieve + query_rewrite (Optional[QueryRewrite]): query rewrite component + rerank (Ranker): reranking component + decay_rate (float): rate at which relevance decays over time + external_storage (Optional[DocumentStorage]): external storage for + persistence """ super().__init__( index_store=index_store, @@ -49,89 +79,288 @@ class TimeWeightedEmbeddingRetriever(EmbeddingRetriever): self.default_salience: Optional[float] = None self._top_k = top_k self._k = 4 + self._external_storage = external_storage + self._use_vector_store_only = False + + # Initialize memory stream + self._initialize_memory_stream() + + def _initialize_memory_stream(self) -> None: + """Initialize memory stream from external storage.""" + if self._external_storage: + try: + self.memory_stream = self._external_storage.get_all_documents() + 
logger.info( + "Loaded memory stream from external storage with " + f"{len(self.memory_stream)} documents" + ) + return + except Exception as e: + logger.error(f"Failed to load memory stream from external storage: {e}") + + # If external storage is not available or loading failed, operate in vector + # store only mode + self._use_vector_store_only = True + logger.info("No memory stream available. Operating in vector store only mode.") + + def _save_memory_stream(self) -> None: + """Save memory stream to external storage.""" + if self._external_storage and self.memory_stream: + try: + success = self._external_storage.save_documents(self.memory_stream) + if success: + logger.debug( + f"Saved {len(self.memory_stream)} documents to external storage" + ) + else: + logger.warning("Failed to save documents to external storage") + except Exception as e: + logger.error(f"Error saving documents to external storage: {e}") def load_document(self, chunks: List[Chunk], **kwargs: Dict[str, Any]) -> List[str]: - """Load document in vector database. + """Load document chunks into vector database. Args: - - chunks: document chunks. - Return chunk ids. 
+ chunks: document chunks to be loaded + **kwargs: additional parameters including current_time + + Returns: + List of chunk IDs """ current_time: Optional[datetime.datetime] = kwargs.get("current_time") # type: ignore # noqa if current_time is None: current_time = datetime.datetime.now() + # Avoid mutating input documents dup_docs = [deepcopy(d) for d in chunks] + + # Generate buffer indices for new documents for i, doc in enumerate(dup_docs): if doc.metadata.get("last_accessed_at") is None: doc.metadata["last_accessed_at"] = current_time if "created_at" not in doc.metadata: doc.metadata["created_at"] = current_time doc.metadata["buffer_idx"] = len(self.memory_stream) + i + + # Add to memory stream self.memory_stream.extend(dup_docs) + + # Save memory stream after adding new documents + self._save_memory_stream() + + # Add to vector store return self._index_store.load_document(dup_docs) def _retrieve( self, query: str, filters: Optional[MetadataFilters] = None ) -> List[Chunk]: - """Retrieve knowledge chunks. + """Retrieve knowledge chunks based on query. Args: query (str): query text - filters: metadata filters. 
- Return: - List[Chunk]: list of chunks + filters: metadata filters + + Returns: + List[Chunk]: list of relevant chunks """ current_time = datetime.datetime.now() - docs_and_scores = { - doc.metadata["buffer_idx"]: (doc, self.default_salience) - for doc in self.memory_stream[-self._k :] - } + + if self._use_vector_store_only: + # If operating in vector store only mode, perform similar search + # and apply time weighting directly on the results + docs_and_scores = self._index_store.similar_search_with_scores( + query, topk=self._top_k, score_threshold=0, filters=filters + ) + + # Apply time weighting to vector store results + rescored_docs = [] + for doc in docs_and_scores: + # Extract time information from metadata + if "last_accessed_at" in doc.metadata and "created_at" in doc.metadata: + rescored_docs.append( + (doc, self._get_combined_score(doc, doc.score, current_time)) + ) + else: + # If time info not available, just use vector similarity + rescored_docs.append((doc, doc.score)) + + # Sort by combined score + rescored_docs.sort(key=lambda x: x[1], reverse=True) + + # Return top k results + return [doc for doc, _ in rescored_docs[: self._k]] + + # Normal operation with memory stream + # Get the most recent documents + docs_and_scores = {} + if self.memory_stream: + docs_and_scores = { + doc.metadata["buffer_idx"]: (doc, self.default_salience) + for doc in self.memory_stream[-self._k :] + if "buffer_idx" in doc.metadata + } + # If a doc is considered salient, update the salience score - docs_and_scores.update(self.get_salient_docs(query)) + docs_and_scores.update(self.get_salient_docs(query, filters)) + + # If no documents found, fall back to vector store query with time weighting + if not docs_and_scores: + return self._retrieve_vector_store_only(query, filters, current_time) + rescored_docs = [ (doc, self._get_combined_score(doc, relevance, current_time)) for doc, relevance in docs_and_scores.values() ] + rescored_docs.sort(key=lambda x: x[1], reverse=True) 
result = [] + # Ensure frequently accessed memories aren't forgotten for doc, _ in rescored_docs[: self._k]: - # TODO: Update vector store doc once `update` method is exposed. - buffered_doc = self.memory_stream[doc.metadata["buffer_idx"]] - buffered_doc.metadata["last_accessed_at"] = current_time - result.append(buffered_doc) + if "buffer_idx" in doc.metadata: + buffer_idx = doc.metadata["buffer_idx"] + if 0 <= buffer_idx < len(self.memory_stream): + buffered_doc = self.memory_stream[buffer_idx] + buffered_doc.metadata["last_accessed_at"] = current_time + result.append(buffered_doc) + else: + # If buffer_idx is invalid, still return the document from vector + # store + result.append(doc) + else: + result.append(doc) + + # Save memory stream after updating access times + self._save_memory_stream() + return result + def _retrieve_vector_store_only( + self, + query: str, + filters: Optional[MetadataFilters], + current_time: datetime.datetime, + ) -> List[Chunk]: + """Retrieve and apply time weighting using only vector store. 
+ + Args: + query: User query text + filters: Optional metadata filters + current_time: Current time for calculating decay + + Returns: + List of relevant chunks + """ + # Get documents from vector store + docs = self._index_store.similar_search_with_scores( + query, topk=self._top_k, score_threshold=0, filters=filters + ) + + # Apply time weighting + rescored_docs = [] + for doc in docs: + # Update last_accessed_at time if it exists + if "last_accessed_at" in doc.metadata: + last_accessed_time = doc.metadata["last_accessed_at"] + hours_passed = _get_hours_passed(current_time, last_accessed_time) + time_score = (1.0 - self.decay_rate) ** hours_passed + # Combine with vector similarity score + combined_score = doc.score + time_score + rescored_docs.append((doc, combined_score)) + else: + # Just use vector similarity if no time data + rescored_docs.append((doc, doc.score)) + + # Sort by combined score + rescored_docs.sort(key=lambda x: x[1], reverse=True) + + # Return top results + return [doc for doc, _ in rescored_docs[: self._k]] + def _get_combined_score( self, chunk: Chunk, vector_relevance: Optional[float], current_time: datetime.datetime, ) -> float: - """Return the combined score for a document.""" - hours_passed = _get_hours_passed( - current_time, - chunk.metadata["last_accessed_at"], - ) + """Calculate combined score for a document based on time decay and relevance. 
+ + Args: + chunk: The document chunk + vector_relevance: Vector similarity score + current_time: Current time for calculating decay + + Returns: + Combined score value + """ + # Default last_accessed_at to creation time if not present + last_accessed_at = chunk.metadata.get("last_accessed_at") + if last_accessed_at is None: + last_accessed_at = chunk.metadata.get("created_at", current_time) + + hours_passed = _get_hours_passed(current_time, last_accessed_at) score = (1.0 - self.decay_rate) ** hours_passed + for key in self.other_score_keys: if key in chunk.metadata: score += chunk.metadata[key] + if vector_relevance is not None: score += vector_relevance + return score - def get_salient_docs(self, query: str) -> Dict[int, Tuple[Chunk, float]]: - """Return documents that are salient to the query.""" + def get_salient_docs( + self, query: str, filters: Optional[MetadataFilters] = None + ) -> Dict[int, Tuple[Chunk, float]]: + """Find documents that are relevant to the query. + + Args: + query: User query text + filters: Optional metadata filters + + Returns: + Dictionary mapping buffer indices to (document, score) tuples + """ docs_and_scores: List[Chunk] docs_and_scores = self._index_store.similar_search_with_scores( - query, topk=self._top_k, score_threshold=0 + query, topk=self._top_k, score_threshold=0, filters=filters ) + results = {} for ck in docs_and_scores: if "buffer_idx" in ck.metadata: buffer_idx = ck.metadata["buffer_idx"] - doc = self.memory_stream[buffer_idx] - results[buffer_idx] = (doc, ck.score) + # Add error handling to prevent IndexError + if 0 <= buffer_idx < len(self.memory_stream): + doc = self.memory_stream[buffer_idx] + results[buffer_idx] = (doc, ck.score) + else: + # If buffer_idx is out of range, still include document but with + # original + results[buffer_idx] = (ck, ck.score) + return results + + def set_external_storage(self, storage: DocumentStorage) -> None: + """Set external storage and reload memory stream. 
+ + Args: + storage: External document storage + """ + self._external_storage = storage + self._use_vector_store_only = False + self._initialize_memory_stream() + + def sync_with_external_storage(self) -> None: + """Sync memory stream with external storage.""" + if self._external_storage: + try: + self.memory_stream = self._external_storage.get_all_documents() + self._use_vector_store_only = False + logger.info( + "Synced memory stream from external storage with " + f"{len(self.memory_stream)} documents" + ) + except Exception as e: + logger.error(f"Failed to sync memory stream from external storage: {e}") diff --git a/packages/dbgpt-serve/src/dbgpt_serve/agent/agents/controller.py b/packages/dbgpt-serve/src/dbgpt_serve/agent/agents/controller.py index 0b5406f33..b60e06730 100644 --- a/packages/dbgpt-serve/src/dbgpt_serve/agent/agents/controller.py +++ b/packages/dbgpt-serve/src/dbgpt_serve/agent/agents/controller.py @@ -15,7 +15,9 @@ from dbgpt.agent import ( AutoPlanChatManager, ConversableAgent, DefaultAWELLayoutManager, + EnhancedShortTermMemory, GptsMemory, + HybridMemory, LLMConfig, ResourceType, UserProxyAgent, @@ -31,6 +33,7 @@ from dbgpt.core.awel.flow.flow_factory import FlowCategory from dbgpt.core.interface.message import StorageConversation from dbgpt.model.cluster import WorkerManagerFactory from dbgpt.model.cluster.client import DefaultLLMClient +from dbgpt.util.executor_utils import ExecutorFactory from dbgpt.util.json_utils import serialize from dbgpt.util.tracer import TracerManager from dbgpt_app.dbgpt_server import system_app @@ -131,12 +134,27 @@ class MultiAgents(BaseComponent, ABC): return self.gpts_app.app_detail(app_code) def get_or_build_agent_memory(self, conv_id: str, dbgpts_name: str) -> AgentMemory: - memory_key = f"{dbgpts_name}_{conv_id}" - if memory_key in self.agent_memory_map: - return self.agent_memory_map[memory_key] + from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory + from dbgpt_serve.rag.storage_manager 
import StorageManager + + executor = self.system_app.get_component( + ComponentType.EXECUTOR_DEFAULT, ExecutorFactory + ).create() + + storage_manager = StorageManager.get_instance(self.system_app) + vector_store = storage_manager.create_vector_store(index_name="_agent_memory_") + embeddings = EmbeddingFactory.get_instance(self.system_app).create() + short_term_memory = EnhancedShortTermMemory( + embeddings, executor=executor, buffer_size=10 + ) + memory = HybridMemory.from_vstore( + vector_store, + embeddings=embeddings, + executor=executor, + short_term_memory=short_term_memory, + ) + agent_memory = AgentMemory(memory, gpts_memory=self.memory) - agent_memory = AgentMemory(gpts_memory=self.memory) - self.agent_memory_map[memory_key] = agent_memory return agent_memory async def agent_chat_v2( diff --git a/packages/dbgpt-serve/src/dbgpt_serve/agent/agents/db_gpts_memory.py b/packages/dbgpt-serve/src/dbgpt_serve/agent/agents/db_gpts_memory.py index d578f130f..1169ded5e 100644 --- a/packages/dbgpt-serve/src/dbgpt_serve/agent/agents/db_gpts_memory.py +++ b/packages/dbgpt-serve/src/dbgpt_serve/agent/agents/db_gpts_memory.py @@ -84,7 +84,6 @@ class MetaDbGptsMessageMemory(GptsMessageMemory): def get_by_agent(self, conv_id: str, agent: str) -> Optional[List[GptsMessage]]: db_results = self.gpts_message.get_by_agent(conv_id, agent) results = [] - db_results = sorted(db_results, key=lambda x: x.rounds) for item in db_results: results.append(GptsMessage.from_dict(item.__dict__)) return results @@ -120,3 +119,6 @@ class MetaDbGptsMessageMemory(GptsMessageMemory): return GptsMessage.from_dict(db_result.__dict__) else: return None + + def delete_by_conv_id(self, conv_id: str) -> None: + self.gpts_message.delete_chat_message(conv_id) diff --git a/packages/dbgpt-serve/src/dbgpt_serve/agent/db/gpts_messages_db.py b/packages/dbgpt-serve/src/dbgpt_serve/agent/db/gpts_messages_db.py index 71da29a3f..acad1b5aa 100644 --- 
a/packages/dbgpt-serve/src/dbgpt_serve/agent/db/gpts_messages_db.py +++ b/packages/dbgpt-serve/src/dbgpt_serve/agent/db/gpts_messages_db.py @@ -1,3 +1,4 @@ +import re from datetime import datetime from typing import List, Optional @@ -14,6 +15,7 @@ from sqlalchemy import ( or_, ) +from dbgpt.agent.util.conv_utils import parse_conv_id from dbgpt.storage.metadata import BaseDao, Model @@ -111,19 +113,31 @@ class GptsMessagesDao(BaseDao): self, conv_id: str, agent: str ) -> Optional[List[GptsMessagesEntity]]: session = self.get_raw_session() + real_conv_id, _ = parse_conv_id(conv_id) gpts_messages = session.query(GptsMessagesEntity) if agent: gpts_messages = gpts_messages.filter( - GptsMessagesEntity.conv_id == conv_id + GptsMessagesEntity.conv_id.like(f"%{real_conv_id}%") ).filter( or_( GptsMessagesEntity.sender == agent, GptsMessagesEntity.receiver == agent, ) ) - result = gpts_messages.order_by(GptsMessagesEntity.rounds).all() + # Extract results first to apply custom sorting + results = gpts_messages.all() + + # Custom sorting based on conv_id suffix and rounds + def get_suffix_number(entity): + suffix_match = re.search(r"_(\d+)$", entity.conv_id) + if suffix_match: + return int(suffix_match.group(1)) + return 0 # Default for entries without a numeric suffix + + # Sort first by numeric suffix, then by rounds + sorted_results = sorted(results, key=lambda x: (get_suffix_number(x), x.rounds)) session.close() - return result + return sorted_results def get_by_conv_id(self, conv_id: str) -> Optional[List[GptsMessagesEntity]]: session = self.get_raw_session() diff --git a/scripts/update_version_all.py b/scripts/update_version_all.py index f91517861..37f2d3c3b 100644 --- a/scripts/update_version_all.py +++ b/scripts/update_version_all.py @@ -53,12 +53,13 @@ from typing import List, Optional @dataclass class VersionChange: """Represents a single version change in a file.""" + file_path: Path file_type: str old_version: str new_version: str package_name: str - + def 
__str__(self): rel_path = self.file_path.as_posix() return f"{self.package_name:<20} {self.file_type:<12} {rel_path:<50} {self.old_version} -> {self.new_version}" @@ -66,15 +67,17 @@ class VersionChange: class VersionUpdater: """Class to handle version updates across the project.""" - + def __init__(self, new_version: str, root_dir: Path, args: argparse.Namespace): self.new_version = new_version self.root_dir = root_dir self.args = args self.changes: List[VersionChange] = [] # Support: X.Y.Z, X.Y.ZrcN, X.Y.Z-alpha.N, X.Y.Z-beta.N, X.Y.Z-rc.N - self.version_pattern = re.compile(r"^\d+\.\d+\.\d+(-[a-zA-Z0-9.]+)?$|^\d+\.\d+\.\d+[a-zA-Z][a-zA-Z0-9.]*$") - + self.version_pattern = re.compile( + r"^\d+\.\d+\.\d+(-[a-zA-Z0-9.]+)?$|^\d+\.\d+\.\d+[a-zA-Z][a-zA-Z0-9.]*$" + ) + def validate_version(self) -> bool: """Validate the version format.""" if not self.version_pattern.match(self.new_version): @@ -83,11 +86,11 @@ class VersionUpdater: print(" - Pre-release: 0.7.0rc0, 0.7.0-beta.1, 1.0.0-alpha.2") return False return True - + def find_main_config(self) -> Optional[Path]: """Find the main project configuration file.""" root_config = self.root_dir / "pyproject.toml" - + if not root_config.exists(): # Try to find it in subdirectories possible_files = list(self.root_dir.glob("**/pyproject.toml")) @@ -97,133 +100,147 @@ class VersionUpdater: else: print("Error: Could not find the project configuration file") return None - + return root_config - + def collect_toml_changes(self, file_path: Path, package_name: str) -> bool: """Collect version changes needed in a TOML file.""" try: # Read the entire file content to preserve formatting with open(file_path, "r", encoding="utf-8") as f: content = f.read() - + # Parse the TOML content to extract version information with open(file_path, "rb") as f: data = tomli.load(f) - + # Check for project.version or tool.poetry.version if "project" in data and "version" in data["project"]: old_version = data["project"]["version"] - 
self.changes.append(VersionChange( - file_path=file_path, - file_type="pyproject.toml", - old_version=old_version, - new_version=self.new_version, - package_name=package_name - )) + self.changes.append( + VersionChange( + file_path=file_path, + file_type="pyproject.toml", + old_version=old_version, + new_version=self.new_version, + package_name=package_name, + ) + ) return True - + # Check for tool.poetry.version - elif "tool" in data and "poetry" in data["tool"] and "version" in data["tool"]["poetry"]: + elif ( + "tool" in data + and "poetry" in data["tool"] + and "version" in data["tool"]["poetry"] + ): old_version = data["tool"]["poetry"]["version"] - self.changes.append(VersionChange( - file_path=file_path, - file_type="pyproject.toml", - old_version=old_version, - new_version=self.new_version, - package_name=package_name - )) + self.changes.append( + VersionChange( + file_path=file_path, + file_type="pyproject.toml", + old_version=old_version, + new_version=self.new_version, + package_name=package_name, + ) + ) return True - + return False - + except Exception as e: print(f"Error analyzing {file_path}: {str(e)}") return False - + def collect_setup_py_changes(self, file_path: Path, package_name: str) -> bool: """Collect version changes needed in a setup.py file.""" try: with open(file_path, "r", encoding="utf-8") as f: content = f.read() - + # Find version pattern - more flexible to detect different formats version_pattern = r'version\s*=\s*["\']([^"\']+)["\']' match = re.search(version_pattern, content) - + if match: old_version = match.group(1) - self.changes.append(VersionChange( - file_path=file_path, - file_type="setup.py", - old_version=old_version, - new_version=self.new_version, - package_name=package_name - )) + self.changes.append( + VersionChange( + file_path=file_path, + file_type="setup.py", + old_version=old_version, + new_version=self.new_version, + package_name=package_name, + ) + ) return True - + return False - + except Exception as e: 
print(f"Error analyzing {file_path}: {str(e)}") return False - + def collect_version_py_changes(self, file_path: Path, package_name: str) -> bool: """Collect version changes needed in a _version.py file.""" try: with open(file_path, "r", encoding="utf-8") as f: content = f.read() - + # Collect version pattern - more flexible to detect different formats # e.g. version = "0.7.0" version_pattern = r'version\s*=\s*["\']([^"\']+)["\']' match = re.search(version_pattern, content) - + if match: old_version = match.group(1) - self.changes.append(VersionChange( - file_path=file_path, - file_type="_version.py", - old_version=old_version, - new_version=self.new_version, - package_name=package_name - )) + self.changes.append( + VersionChange( + file_path=file_path, + file_type="_version.py", + old_version=old_version, + new_version=self.new_version, + package_name=package_name, + ) + ) return True - + return False - + except Exception as e: print(f"Error analyzing {file_path}: {str(e)}") return False - + def collect_json_changes(self, file_path: Path, package_name: str) -> bool: """Collect version changes needed in a JSON file.""" try: with open(file_path, "r", encoding="utf-8") as f: content = f.read() data = json.loads(content) - + if "version" in data: old_version = data["version"] - self.changes.append(VersionChange( - file_path=file_path, - file_type="package.json", - old_version=old_version, - new_version=self.new_version, - package_name=package_name - )) + self.changes.append( + VersionChange( + file_path=file_path, + file_type="package.json", + old_version=old_version, + new_version=self.new_version, + package_name=package_name, + ) + ) return True - + return False - + except Exception as e: print(f"Error analyzing {file_path}: {str(e)}") return False - + def find_workspace_members(self, workspace_members: List[str]) -> List[Path]: """Find all workspace member directories.""" members = [] - + for pattern in workspace_members: # Handle glob patterns if "*" in pattern: 
@@ -233,68 +250,72 @@ class VersionUpdater: path = self.root_dir / pattern if path.exists(): members.append(path) - + return members - + def collect_all_changes(self) -> bool: """Collect all version changes needed across the project.""" # Find main project configuration root_config = self.find_main_config() if not root_config: return False - + # Start with the main config file self.collect_toml_changes(root_config, "root-project") - + # Find and parse workspace members from configuration workspace_members = [] try: with open(root_config, "rb") as f: data = tomli.load(f) - - if "tool" in data and "uv" in data["tool"] and "workspace" in data["tool"]["uv"]: + + if ( + "tool" in data + and "uv" in data["tool"] + and "workspace" in data["tool"]["uv"] + ): workspace_members = data["tool"]["uv"]["workspace"].get("members", []) except Exception as e: print(f"Warning: Could not parse workspace members: {str(e)}") - + # Find all package directories package_dirs = self.find_workspace_members(workspace_members) print(f"Found {len(package_dirs)} workspace packages to check") - + # Check each package directory for version files for pkg_dir in package_dirs: package_name = pkg_dir.name - + # Skip if filter is applied and doesn't match if self.args.filter and self.args.filter not in package_name: continue - + # Check for pyproject.toml pkg_toml = pkg_dir / "pyproject.toml" if pkg_toml.exists(): self.collect_toml_changes(pkg_toml, package_name) - + # Check for setup.py setup_py = pkg_dir / "setup.py" if setup_py.exists(): self.collect_setup_py_changes(setup_py, package_name) - + # Check for package.json package_json = pkg_dir / "package.json" if package_json.exists(): self.collect_json_changes(package_json, package_name) - + # Check for _version.py files version_py_files = list(pkg_dir.glob("**/_version.py")) for version_py in version_py_files: self.collect_version_py_changes(version_py, package_name) - + return len(self.changes) > 0 - + def apply_changes(self) -> int: """Apply all 
collected changes.""" applied_count = 0 - + for change in self.changes: try: if change.file_type == "pyproject.toml": @@ -305,157 +326,158 @@ class VersionUpdater: self._update_json_file(change.file_path) elif change.file_type == "_version.py": self._update_version_py_file(change.file_path) - + applied_count += 1 print(f"✅ Updated {change.file_path}") - + except Exception as e: print(f"❌ Failed to update {change.file_path}: {str(e)}") - + return applied_count - + def _update_toml_file(self, file_path: Path) -> None: """Update version in a TOML file using regex to preserve formatting.""" with open(file_path, "r", encoding="utf-8") as f: content = f.read() - - updated = False - - # Update project.version - project_version_pattern = r'(\[project\][^\[]*?version\s*=\s*["\'](.*?)["\']\s*)' - if re.search(project_version_pattern, content, re.DOTALL): + updated = False + + # Update project.version + project_version_pattern = ( + r'(\[project\][^\[]*?version\s*=\s*["\'](.*?)["\']\s*)' + ) + if re.search(project_version_pattern, content, re.DOTALL): project_pattern = r'(\[project\][^\[]*?version\s*=\s*["\'](.*?)["\']\s*)' content = re.sub( project_pattern, lambda m: m.group(0).replace(m.group(2), self.new_version), content, - flags=re.DOTALL + flags=re.DOTALL, ) updated = True - - poetry_version_pattern = r'(\[tool\.poetry\][^\[]*?version\s*=\s*["\'](.*?)["\']\s*)' + + poetry_version_pattern = ( + r'(\[tool\.poetry\][^\[]*?version\s*=\s*["\'](.*?)["\']\s*)' + ) if re.search(poetry_version_pattern, content, re.DOTALL): - poetry_pattern = r'(\[tool\.poetry\][^\[]*?version\s*=\s*["\'](.*?)["\']\s*)' + poetry_pattern = ( + r'(\[tool\.poetry\][^\[]*?version\s*=\s*["\'](.*?)["\']\s*)' + ) content = re.sub( poetry_pattern, lambda m: m.group(0).replace(m.group(2), self.new_version), content, - flags=re.DOTALL + flags=re.DOTALL, ) updated = True - + if not updated: version_line_pattern = r'(^version\s*=\s*["\'](.*?)["\']\s*$)' content = re.sub( version_line_pattern, lambda m: 
m.group(0).replace(m.group(2), self.new_version), content, - flags=re.MULTILINE + flags=re.MULTILINE, ) - + with open(file_path, "w", encoding="utf-8") as f: f.write(content) - + def _update_setup_py_file(self, file_path: Path) -> None: """Update version in a setup.py file.""" with open(file_path, "r", encoding="utf-8") as f: content = f.read() - + # Find and replace version version_pattern = r'(version\s*=\s*["\'])([^"\']+)(["\'])' updated_content = re.sub( - version_pattern, - rf'\g<1>{self.new_version}\g<3>', - content + version_pattern, rf"\g<1>{self.new_version}\g<3>", content ) - + with open(file_path, "w", encoding="utf-8") as f: f.write(updated_content) - + def _update_version_py_file(self, file_path: Path) -> None: """Update version in a _version.py file.""" with open(file_path, "r", encoding="utf-8") as f: content = f.read() - + version_pattern = r'(version\s*=\s*["\'])([^"\']+)(["\'])' updated_content = re.sub( - version_pattern, - rf'\g<1>{self.new_version}\g<3>', - content + version_pattern, rf"\g<1>{self.new_version}\g<3>", content ) - + with open(file_path, "w", encoding="utf-8") as f: f.write(updated_content) - + def _update_json_file(self, file_path: Path) -> None: """Update version in a JSON file while preserving formatting.""" with open(file_path, "r", encoding="utf-8") as f: content = f.read() - + version_pattern = r'("version"\s*:\s*")([^"]+)(")' updated_content = re.sub( - version_pattern, - rf'\g<1>{self.new_version}\g<3>', - content + version_pattern, rf"\g<1>{self.new_version}\g<3>", content ) - + with open(file_path, "w", encoding="utf-8") as f: f.write(updated_content) - + def show_changes(self) -> None: """Display the collected changes.""" if not self.changes: print("No changes to apply.") return - + print("\n" + "=" * 100) print(f"Version changes to apply: {self.new_version}") print("=" * 100) print(f"{'Package':<20} {'File Type':<12} {'Path':<50} {'Version Change'}") print("-" * 100) - + for change in self.changes: print(str(change)) - 
+ print("=" * 100) print(f"Total: {len(self.changes)} file(s) to update") print("=" * 100) - + def prompt_for_confirmation(self) -> bool: """Prompt the user for confirmation.""" if self.args.yes: return True - + response = input("\nApply these changes? [y/N]: ").strip().lower() - return response in ['y', 'yes'] - + return response in ["y", "yes"] + def run(self) -> bool: """Run the updater.""" if not self.validate_version(): return False - + # Collect all changes if not self.collect_all_changes(): print("No files found that need version updates.") return False - + # Show the changes self.show_changes() - + # If dry run, exit now if self.args.dry_run: print("\nDry run complete. No changes were applied.") return True - + # Prompt for confirmation if not self.prompt_for_confirmation(): print("\nOperation cancelled. No changes were applied.") return False - + # Apply the changes applied_count = self.apply_changes() - print(f"\n🎉 Version update complete! Updated {applied_count} files to version {self.new_version}") + print( + f"\n🎉 Version update complete! 
Updated {applied_count} files to version {self.new_version}" + ) return True @@ -464,32 +486,39 @@ def parse_args(): parser = argparse.ArgumentParser( description="Update version numbers across the dbgpt-mono project", formatter_class=argparse.RawDescriptionHelpFormatter, - epilog=__doc__.split("\n\n")[2] # Extract usage examples + epilog=__doc__.split("\n\n")[2], # Extract usage examples ) - - parser.add_argument("version", help="New version number (supports standard and pre-release formats)") - parser.add_argument("-y", "--yes", action="store_true", help="Apply changes without confirmation") - parser.add_argument("-d", "--dry-run", action="store_true", help="Only show changes without applying them") - parser.add_argument("-f", "--filter", help="Only update packages containing this string") - + + parser.add_argument( + "version", help="New version number (supports standard and pre-release formats)" + ) + parser.add_argument( + "-y", "--yes", action="store_true", help="Apply changes without confirmation" + ) + parser.add_argument( + "-d", + "--dry-run", + action="store_true", + help="Only show changes without applying them", + ) + parser.add_argument( + "-f", "--filter", help="Only update packages containing this string" + ) + return parser.parse_args() def main(): """Main entry point for the script.""" args = parse_args() - + # Initialize the updater - updater = VersionUpdater( - new_version=args.version, - root_dir=Path("../"), - args=args - ) - + updater = VersionUpdater(new_version=args.version, root_dir=Path("../"), args=args) + # Run the updater success = updater.run() sys.exit(0 if success else 1) if __name__ == "__main__": - main() \ No newline at end of file + main()