mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-22 20:01:46 +00:00
feat: Support intent detection (#1588)
This commit is contained in:
parent
73d175a127
commit
a88af6f87d
6
Makefile
6
Makefile
@ -47,7 +47,7 @@ fmt: setup ## Format Python code
|
||||
$(VENV_BIN)/blackdoc examples
|
||||
# TODO: Use flake8 to enforce Python style guide.
|
||||
# https://flake8.pycqa.org/en/latest/
|
||||
$(VENV_BIN)/flake8 dbgpt/core/ dbgpt/rag/ dbgpt/storage/ dbgpt/datasource/ dbgpt/client/ dbgpt/agent/ dbgpt/vis/
|
||||
$(VENV_BIN)/flake8 dbgpt/core/ dbgpt/rag/ dbgpt/storage/ dbgpt/datasource/ dbgpt/client/ dbgpt/agent/ dbgpt/vis/ dbgpt/experimental/
|
||||
# TODO: More package checks with flake8.
|
||||
|
||||
.PHONY: fmt-check
|
||||
@ -56,7 +56,7 @@ fmt-check: setup ## Check Python code formatting and style without making change
|
||||
$(VENV_BIN)/isort --check-only --extend-skip="examples/notebook" examples
|
||||
$(VENV_BIN)/black --check --extend-exclude="examples/notebook" .
|
||||
$(VENV_BIN)/blackdoc --check dbgpt examples
|
||||
$(VENV_BIN)/flake8 dbgpt/core/ dbgpt/rag/ dbgpt/storage/ dbgpt/datasource/ dbgpt/client/ dbgpt/agent/ dbgpt/vis/
|
||||
$(VENV_BIN)/flake8 dbgpt/core/ dbgpt/rag/ dbgpt/storage/ dbgpt/datasource/ dbgpt/client/ dbgpt/agent/ dbgpt/vis/ dbgpt/experimental/
|
||||
|
||||
.PHONY: pre-commit
|
||||
pre-commit: fmt-check test test-doc mypy ## Run formatting and unit tests before committing
|
||||
@ -72,7 +72,7 @@ test-doc: $(VENV)/.testenv ## Run doctests
|
||||
.PHONY: mypy
|
||||
mypy: $(VENV)/.testenv ## Run mypy checks
|
||||
# https://github.com/python/mypy
|
||||
$(VENV_BIN)/mypy --config-file .mypy.ini dbgpt/rag/ dbgpt/datasource/ dbgpt/client/ dbgpt/agent/ dbgpt/vis/
|
||||
$(VENV_BIN)/mypy --config-file .mypy.ini dbgpt/rag/ dbgpt/datasource/ dbgpt/client/ dbgpt/agent/ dbgpt/vis/ dbgpt/experimental/
|
||||
# rag depends on core and storage, so we not need to check it again.
|
||||
# $(VENV_BIN)/mypy --config-file .mypy.ini dbgpt/storage/
|
||||
# $(VENV_BIN)/mypy --config-file .mypy.ini dbgpt/core/
|
||||
|
3
assets/schema/upgrade/v_0_5_7/upgrade_to_v0.5.7.sql
Normal file
3
assets/schema/upgrade/v_0_5_7/upgrade_to_v0.5.7.sql
Normal file
@ -0,0 +1,3 @@
|
||||
USE dbgpt;
|
||||
ALTER TABLE dbgpt_serve_flow
|
||||
ADD COLUMN `define_type` varchar(32) null comment 'Flow define type(json or python)' after `version`;
|
395
assets/schema/upgrade/v_0_5_7/v0.5.6.sql
Normal file
395
assets/schema/upgrade/v_0_5_7/v0.5.6.sql
Normal file
@ -0,0 +1,395 @@
|
||||
-- Full SQL of v0.5.6, please not modify this file(It must be same as the file in the release package)
|
||||
|
||||
CREATE
|
||||
DATABASE IF NOT EXISTS dbgpt;
|
||||
use dbgpt;
|
||||
|
||||
-- For alembic migration tool
|
||||
CREATE TABLE IF NOT EXISTS `alembic_version`
|
||||
(
|
||||
version_num VARCHAR(32) NOT NULL,
|
||||
CONSTRAINT alembic_version_pkc PRIMARY KEY (version_num)
|
||||
) DEFAULT CHARSET=utf8mb4 ;
|
||||
|
||||
CREATE TABLE IF NOT EXISTS `knowledge_space`
|
||||
(
|
||||
`id` int NOT NULL AUTO_INCREMENT COMMENT 'auto increment id',
|
||||
`name` varchar(100) NOT NULL COMMENT 'knowledge space name',
|
||||
`vector_type` varchar(50) NOT NULL COMMENT 'vector type',
|
||||
`desc` varchar(500) NOT NULL COMMENT 'description',
|
||||
`owner` varchar(100) DEFAULT NULL COMMENT 'owner',
|
||||
`context` TEXT DEFAULT NULL COMMENT 'context argument',
|
||||
`gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
|
||||
`gmt_modified` TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time',
|
||||
PRIMARY KEY (`id`),
|
||||
KEY `idx_name` (`name`) COMMENT 'index:idx_name'
|
||||
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT='knowledge space table';
|
||||
|
||||
CREATE TABLE IF NOT EXISTS `knowledge_document`
|
||||
(
|
||||
`id` int NOT NULL AUTO_INCREMENT COMMENT 'auto increment id',
|
||||
`doc_name` varchar(100) NOT NULL COMMENT 'document path name',
|
||||
`doc_type` varchar(50) NOT NULL COMMENT 'doc type',
|
||||
`space` varchar(50) NOT NULL COMMENT 'knowledge space',
|
||||
`chunk_size` int NOT NULL COMMENT 'chunk size',
|
||||
`last_sync` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'last sync time',
|
||||
`status` varchar(50) NOT NULL COMMENT 'status TODO,RUNNING,FAILED,FINISHED',
|
||||
`content` LONGTEXT NOT NULL COMMENT 'knowledge embedding sync result',
|
||||
`result` TEXT NULL COMMENT 'knowledge content',
|
||||
`vector_ids` LONGTEXT NULL COMMENT 'vector_ids',
|
||||
`summary` LONGTEXT NULL COMMENT 'knowledge summary',
|
||||
`gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
|
||||
`gmt_modified` TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time',
|
||||
PRIMARY KEY (`id`),
|
||||
KEY `idx_doc_name` (`doc_name`) COMMENT 'index:idx_doc_name'
|
||||
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT='knowledge document table';
|
||||
|
||||
CREATE TABLE IF NOT EXISTS `document_chunk`
|
||||
(
|
||||
`id` int NOT NULL AUTO_INCREMENT COMMENT 'auto increment id',
|
||||
`doc_name` varchar(100) NOT NULL COMMENT 'document path name',
|
||||
`doc_type` varchar(50) NOT NULL COMMENT 'doc type',
|
||||
`document_id` int NOT NULL COMMENT 'document parent id',
|
||||
`content` longtext NOT NULL COMMENT 'chunk content',
|
||||
`meta_info` varchar(200) NOT NULL COMMENT 'metadata info',
|
||||
`gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
|
||||
`gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time',
|
||||
PRIMARY KEY (`id`),
|
||||
KEY `idx_document_id` (`document_id`) COMMENT 'index:document_id'
|
||||
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT='knowledge document chunk detail';
|
||||
|
||||
|
||||
|
||||
CREATE TABLE IF NOT EXISTS `connect_config`
|
||||
(
|
||||
`id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id',
|
||||
`db_type` varchar(255) NOT NULL COMMENT 'db type',
|
||||
`db_name` varchar(255) NOT NULL COMMENT 'db name',
|
||||
`db_path` varchar(255) DEFAULT NULL COMMENT 'file db path',
|
||||
`db_host` varchar(255) DEFAULT NULL COMMENT 'db connect host(not file db)',
|
||||
`db_port` varchar(255) DEFAULT NULL COMMENT 'db cnnect port(not file db)',
|
||||
`db_user` varchar(255) DEFAULT NULL COMMENT 'db user',
|
||||
`db_pwd` varchar(255) DEFAULT NULL COMMENT 'db password',
|
||||
`comment` text COMMENT 'db comment',
|
||||
`sys_code` varchar(128) DEFAULT NULL COMMENT 'System code',
|
||||
PRIMARY KEY (`id`),
|
||||
UNIQUE KEY `uk_db` (`db_name`),
|
||||
KEY `idx_q_db_type` (`db_type`)
|
||||
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT 'Connection confi';
|
||||
|
||||
CREATE TABLE IF NOT EXISTS `chat_history`
|
||||
(
|
||||
`id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id',
|
||||
`conv_uid` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'Conversation record unique id',
|
||||
`chat_mode` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'Conversation scene mode',
|
||||
`summary` longtext COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'Conversation record summary',
|
||||
`user_name` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'interlocutor',
|
||||
`messages` text COLLATE utf8mb4_unicode_ci COMMENT 'Conversation details',
|
||||
`message_ids` text COLLATE utf8mb4_unicode_ci COMMENT 'Message id list, split by comma',
|
||||
`sys_code` varchar(128) DEFAULT NULL COMMENT 'System code',
|
||||
`gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
|
||||
`gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time',
|
||||
UNIQUE KEY `conv_uid` (`conv_uid`),
|
||||
PRIMARY KEY (`id`)
|
||||
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT 'Chat history';
|
||||
|
||||
CREATE TABLE IF NOT EXISTS `chat_history_message`
|
||||
(
|
||||
`id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id',
|
||||
`conv_uid` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'Conversation record unique id',
|
||||
`index` int NOT NULL COMMENT 'Message index',
|
||||
`round_index` int NOT NULL COMMENT 'Round of conversation',
|
||||
`message_detail` text COLLATE utf8mb4_unicode_ci COMMENT 'Message details, json format',
|
||||
`gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
|
||||
`gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time',
|
||||
UNIQUE KEY `message_uid_index` (`conv_uid`, `index`),
|
||||
PRIMARY KEY (`id`)
|
||||
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT 'Chat history message';
|
||||
|
||||
CREATE TABLE IF NOT EXISTS `chat_feed_back`
|
||||
(
|
||||
`id` bigint(20) NOT NULL AUTO_INCREMENT,
|
||||
`conv_uid` varchar(128) DEFAULT NULL COMMENT 'Conversation ID',
|
||||
`conv_index` int(4) DEFAULT NULL COMMENT 'Round of conversation',
|
||||
`score` int(1) DEFAULT NULL COMMENT 'Score of user',
|
||||
`ques_type` varchar(32) DEFAULT NULL COMMENT 'User question category',
|
||||
`question` longtext DEFAULT NULL COMMENT 'User question',
|
||||
`knowledge_space` varchar(128) DEFAULT NULL COMMENT 'Knowledge space name',
|
||||
`messages` longtext DEFAULT NULL COMMENT 'The details of user feedback',
|
||||
`user_name` varchar(128) DEFAULT NULL COMMENT 'User name',
|
||||
`gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
|
||||
`gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time',
|
||||
PRIMARY KEY (`id`),
|
||||
UNIQUE KEY `uk_conv` (`conv_uid`,`conv_index`),
|
||||
KEY `idx_conv` (`conv_uid`,`conv_index`)
|
||||
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT='User feedback table';
|
||||
|
||||
|
||||
CREATE TABLE IF NOT EXISTS `my_plugin`
|
||||
(
|
||||
`id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id',
|
||||
`tenant` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'user tenant',
|
||||
`user_code` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'user code',
|
||||
`user_name` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'user name',
|
||||
`name` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'plugin name',
|
||||
`file_name` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'plugin package file name',
|
||||
`type` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin type',
|
||||
`version` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin version',
|
||||
`use_count` int DEFAULT NULL COMMENT 'plugin total use count',
|
||||
`succ_count` int DEFAULT NULL COMMENT 'plugin total success count',
|
||||
`sys_code` varchar(128) DEFAULT NULL COMMENT 'System code',
|
||||
`gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'plugin install time',
|
||||
PRIMARY KEY (`id`),
|
||||
UNIQUE KEY `name` (`name`)
|
||||
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='User plugin table';
|
||||
|
||||
CREATE TABLE IF NOT EXISTS `plugin_hub`
|
||||
(
|
||||
`id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id',
|
||||
`name` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'plugin name',
|
||||
`description` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'plugin description',
|
||||
`author` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin author',
|
||||
`email` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin author email',
|
||||
`type` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin type',
|
||||
`version` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin version',
|
||||
`storage_channel` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin storage channel',
|
||||
`storage_url` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin download url',
|
||||
`download_param` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin download param',
|
||||
`gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'plugin upload time',
|
||||
`installed` int DEFAULT NULL COMMENT 'plugin already installed count',
|
||||
PRIMARY KEY (`id`),
|
||||
UNIQUE KEY `name` (`name`)
|
||||
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='Plugin Hub table';
|
||||
|
||||
|
||||
CREATE TABLE IF NOT EXISTS `prompt_manage`
|
||||
(
|
||||
`id` int(11) NOT NULL AUTO_INCREMENT,
|
||||
`chat_scene` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Chat scene',
|
||||
`sub_chat_scene` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Sub chat scene',
|
||||
`prompt_type` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt type: common or private',
|
||||
`prompt_name` varchar(256) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'prompt name',
|
||||
`content` longtext COLLATE utf8mb4_unicode_ci COMMENT 'Prompt content',
|
||||
`input_variables` varchar(1024) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt input variables(split by comma))',
|
||||
`model` varchar(128) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt model name(we can use different models for different prompt)',
|
||||
`prompt_language` varchar(32) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt language(eg:en, zh-cn)',
|
||||
`prompt_format` varchar(32) COLLATE utf8mb4_unicode_ci DEFAULT 'f-string' COMMENT 'Prompt format(eg: f-string, jinja2)',
|
||||
`prompt_desc` varchar(512) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt description',
|
||||
`user_name` varchar(128) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'User name',
|
||||
`sys_code` varchar(128) DEFAULT NULL COMMENT 'System code',
|
||||
`gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
|
||||
`gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time',
|
||||
PRIMARY KEY (`id`),
|
||||
UNIQUE KEY `prompt_name_uiq` (`prompt_name`, `sys_code`, `prompt_language`, `model`),
|
||||
KEY `gmt_created_idx` (`gmt_created`)
|
||||
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='Prompt management table';
|
||||
|
||||
CREATE TABLE IF NOT EXISTS `gpts_conversations` (
|
||||
`id` int(11) NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id',
|
||||
`conv_id` varchar(255) NOT NULL COMMENT 'The unique id of the conversation record',
|
||||
`user_goal` text NOT NULL COMMENT 'User''s goals content',
|
||||
`gpts_name` varchar(255) NOT NULL COMMENT 'The gpts name',
|
||||
`state` varchar(255) DEFAULT NULL COMMENT 'The gpts state',
|
||||
`max_auto_reply_round` int(11) NOT NULL COMMENT 'max auto reply round',
|
||||
`auto_reply_count` int(11) NOT NULL COMMENT 'auto reply count',
|
||||
`user_code` varchar(255) DEFAULT NULL COMMENT 'user code',
|
||||
`sys_code` varchar(255) DEFAULT NULL COMMENT 'system app ',
|
||||
`created_at` datetime DEFAULT NULL COMMENT 'create time',
|
||||
`updated_at` datetime DEFAULT NULL COMMENT 'last update time',
|
||||
`team_mode` varchar(255) NULL COMMENT 'agent team work mode',
|
||||
|
||||
PRIMARY KEY (`id`),
|
||||
UNIQUE KEY `uk_gpts_conversations` (`conv_id`),
|
||||
KEY `idx_gpts_name` (`gpts_name`)
|
||||
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT="gpt conversations";
|
||||
|
||||
CREATE TABLE IF NOT EXISTS `gpts_instance` (
|
||||
`id` int(11) NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id',
|
||||
`gpts_name` varchar(255) NOT NULL COMMENT 'Current AI assistant name',
|
||||
`gpts_describe` varchar(2255) NOT NULL COMMENT 'Current AI assistant describe',
|
||||
`resource_db` text COMMENT 'List of structured database names contained in the current gpts',
|
||||
`resource_internet` text COMMENT 'Is it possible to retrieve information from the internet',
|
||||
`resource_knowledge` text COMMENT 'List of unstructured database names contained in the current gpts',
|
||||
`gpts_agents` varchar(1000) DEFAULT NULL COMMENT 'List of agents names contained in the current gpts',
|
||||
`gpts_models` varchar(1000) DEFAULT NULL COMMENT 'List of llm model names contained in the current gpts',
|
||||
`language` varchar(100) DEFAULT NULL COMMENT 'gpts language',
|
||||
`user_code` varchar(255) NOT NULL COMMENT 'user code',
|
||||
`sys_code` varchar(255) DEFAULT NULL COMMENT 'system app code',
|
||||
`created_at` datetime DEFAULT NULL COMMENT 'create time',
|
||||
`updated_at` datetime DEFAULT NULL COMMENT 'last update time',
|
||||
`team_mode` varchar(255) NOT NULL COMMENT 'Team work mode',
|
||||
`is_sustainable` tinyint(1) NOT NULL COMMENT 'Applications for sustainable dialogue',
|
||||
PRIMARY KEY (`id`),
|
||||
UNIQUE KEY `uk_gpts` (`gpts_name`)
|
||||
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT="gpts instance";
|
||||
|
||||
CREATE TABLE `gpts_messages` (
|
||||
`id` int(11) NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id',
|
||||
`conv_id` varchar(255) NOT NULL COMMENT 'The unique id of the conversation record',
|
||||
`sender` varchar(255) NOT NULL COMMENT 'Who speaking in the current conversation turn',
|
||||
`receiver` varchar(255) NOT NULL COMMENT 'Who receive message in the current conversation turn',
|
||||
`model_name` varchar(255) DEFAULT NULL COMMENT 'message generate model',
|
||||
`rounds` int(11) NOT NULL COMMENT 'dialogue turns',
|
||||
`content` text COMMENT 'Content of the speech',
|
||||
`current_goal` text COMMENT 'The target corresponding to the current message',
|
||||
`context` text COMMENT 'Current conversation context',
|
||||
`review_info` text COMMENT 'Current conversation review info',
|
||||
`action_report` text COMMENT 'Current conversation action report',
|
||||
`role` varchar(255) DEFAULT NULL COMMENT 'The role of the current message content',
|
||||
`created_at` datetime DEFAULT NULL COMMENT 'create time',
|
||||
`updated_at` datetime DEFAULT NULL COMMENT 'last update time',
|
||||
PRIMARY KEY (`id`),
|
||||
KEY `idx_q_messages` (`conv_id`,`rounds`,`sender`)
|
||||
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT="gpts message";
|
||||
|
||||
|
||||
CREATE TABLE `gpts_plans` (
|
||||
`id` int(11) NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id',
|
||||
`conv_id` varchar(255) NOT NULL COMMENT 'The unique id of the conversation record',
|
||||
`sub_task_num` int(11) NOT NULL COMMENT 'Subtask number',
|
||||
`sub_task_title` varchar(255) NOT NULL COMMENT 'subtask title',
|
||||
`sub_task_content` text NOT NULL COMMENT 'subtask content',
|
||||
`sub_task_agent` varchar(255) DEFAULT NULL COMMENT 'Available agents corresponding to subtasks',
|
||||
`resource_name` varchar(255) DEFAULT NULL COMMENT 'resource name',
|
||||
`rely` varchar(255) DEFAULT NULL COMMENT 'Subtask dependencies,like: 1,2,3',
|
||||
`agent_model` varchar(255) DEFAULT NULL COMMENT 'LLM model used by subtask processing agents',
|
||||
`retry_times` int(11) DEFAULT NULL COMMENT 'number of retries',
|
||||
`max_retry_times` int(11) DEFAULT NULL COMMENT 'Maximum number of retries',
|
||||
`state` varchar(255) DEFAULT NULL COMMENT 'subtask status',
|
||||
`result` longtext COMMENT 'subtask result',
|
||||
`created_at` datetime DEFAULT NULL COMMENT 'create time',
|
||||
`updated_at` datetime DEFAULT NULL COMMENT 'last update time',
|
||||
PRIMARY KEY (`id`),
|
||||
UNIQUE KEY `uk_sub_task` (`conv_id`,`sub_task_num`)
|
||||
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT="gpt plan";
|
||||
|
||||
-- dbgpt.dbgpt_serve_flow definition
|
||||
CREATE TABLE `dbgpt_serve_flow` (
|
||||
`id` int NOT NULL AUTO_INCREMENT COMMENT 'Auto increment id',
|
||||
`uid` varchar(128) NOT NULL COMMENT 'Unique id',
|
||||
`dag_id` varchar(128) DEFAULT NULL COMMENT 'DAG id',
|
||||
`name` varchar(128) DEFAULT NULL COMMENT 'Flow name',
|
||||
`flow_data` text COMMENT 'Flow data, JSON format',
|
||||
`user_name` varchar(128) DEFAULT NULL COMMENT 'User name',
|
||||
`sys_code` varchar(128) DEFAULT NULL COMMENT 'System code',
|
||||
`gmt_created` datetime DEFAULT NULL COMMENT 'Record creation time',
|
||||
`gmt_modified` datetime DEFAULT NULL COMMENT 'Record update time',
|
||||
`flow_category` varchar(64) DEFAULT NULL COMMENT 'Flow category',
|
||||
`description` varchar(512) DEFAULT NULL COMMENT 'Flow description',
|
||||
`state` varchar(32) DEFAULT NULL COMMENT 'Flow state',
|
||||
`error_message` varchar(512) NULL comment 'Error message',
|
||||
`source` varchar(64) DEFAULT NULL COMMENT 'Flow source',
|
||||
`source_url` varchar(512) DEFAULT NULL COMMENT 'Flow source url',
|
||||
`version` varchar(32) DEFAULT NULL COMMENT 'Flow version',
|
||||
`label` varchar(128) DEFAULT NULL COMMENT 'Flow label',
|
||||
`editable` int DEFAULT NULL COMMENT 'Editable, 0: editable, 1: not editable',
|
||||
PRIMARY KEY (`id`),
|
||||
UNIQUE KEY `uk_uid` (`uid`),
|
||||
KEY `ix_dbgpt_serve_flow_sys_code` (`sys_code`),
|
||||
KEY `ix_dbgpt_serve_flow_uid` (`uid`),
|
||||
KEY `ix_dbgpt_serve_flow_dag_id` (`dag_id`),
|
||||
KEY `ix_dbgpt_serve_flow_user_name` (`user_name`),
|
||||
KEY `ix_dbgpt_serve_flow_name` (`name`)
|
||||
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;
|
||||
|
||||
-- dbgpt.gpts_app definition
|
||||
CREATE TABLE `gpts_app` (
|
||||
`id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id',
|
||||
`app_code` varchar(255) NOT NULL COMMENT 'Current AI assistant code',
|
||||
`app_name` varchar(255) NOT NULL COMMENT 'Current AI assistant name',
|
||||
`app_describe` varchar(2255) NOT NULL COMMENT 'Current AI assistant describe',
|
||||
`language` varchar(100) NOT NULL COMMENT 'gpts language',
|
||||
`team_mode` varchar(255) NOT NULL COMMENT 'Team work mode',
|
||||
`team_context` text COMMENT 'The execution logic and team member content that teams with different working modes rely on',
|
||||
`user_code` varchar(255) DEFAULT NULL COMMENT 'user code',
|
||||
`sys_code` varchar(255) DEFAULT NULL COMMENT 'system app code',
|
||||
`created_at` datetime DEFAULT NULL COMMENT 'create time',
|
||||
`updated_at` datetime DEFAULT NULL COMMENT 'last update time',
|
||||
`icon` varchar(1024) DEFAULT NULL COMMENT 'app icon, url',
|
||||
PRIMARY KEY (`id`),
|
||||
UNIQUE KEY `uk_gpts_app` (`app_name`)
|
||||
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;
|
||||
|
||||
CREATE TABLE `gpts_app_collection` (
|
||||
`id` int(11) NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id',
|
||||
`app_code` varchar(255) NOT NULL COMMENT 'Current AI assistant code',
|
||||
`user_code` int(11) NOT NULL COMMENT 'user code',
|
||||
`sys_code` varchar(255) NOT NULL COMMENT 'system app code',
|
||||
`created_at` datetime DEFAULT NULL COMMENT 'create time',
|
||||
`updated_at` datetime DEFAULT NULL COMMENT 'last update time',
|
||||
PRIMARY KEY (`id`),
|
||||
KEY `idx_app_code` (`app_code`),
|
||||
KEY `idx_user_code` (`user_code`)
|
||||
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT="gpt collections";
|
||||
|
||||
-- dbgpt.gpts_app_detail definition
|
||||
CREATE TABLE `gpts_app_detail` (
|
||||
`id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id',
|
||||
`app_code` varchar(255) NOT NULL COMMENT 'Current AI assistant code',
|
||||
`app_name` varchar(255) NOT NULL COMMENT 'Current AI assistant name',
|
||||
`agent_name` varchar(255) NOT NULL COMMENT ' Agent name',
|
||||
`node_id` varchar(255) NOT NULL COMMENT 'Current AI assistant Agent Node id',
|
||||
`resources` text COMMENT 'Agent bind resource',
|
||||
`prompt_template` text COMMENT 'Agent bind template',
|
||||
`llm_strategy` varchar(25) DEFAULT NULL COMMENT 'Agent use llm strategy',
|
||||
`llm_strategy_value` text COMMENT 'Agent use llm strategy value',
|
||||
`created_at` datetime DEFAULT NULL COMMENT 'create time',
|
||||
`updated_at` datetime DEFAULT NULL COMMENT 'last update time',
|
||||
PRIMARY KEY (`id`),
|
||||
UNIQUE KEY `uk_gpts_app_agent_node` (`app_name`,`agent_name`,`node_id`)
|
||||
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;
|
||||
|
||||
CREATE
|
||||
DATABASE IF NOT EXISTS EXAMPLE_1;
|
||||
use EXAMPLE_1;
|
||||
CREATE TABLE IF NOT EXISTS `users`
|
||||
(
|
||||
`id` int NOT NULL AUTO_INCREMENT,
|
||||
`username` varchar(50) NOT NULL COMMENT '用户名',
|
||||
`password` varchar(50) NOT NULL COMMENT '密码',
|
||||
`email` varchar(50) NOT NULL COMMENT '邮箱',
|
||||
`phone` varchar(20) DEFAULT NULL COMMENT '电话',
|
||||
PRIMARY KEY (`id`),
|
||||
KEY `idx_username` (`username`) COMMENT '索引:按用户名查询'
|
||||
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT='聊天用户表';
|
||||
|
||||
INSERT INTO users (username, password, email, phone)
|
||||
VALUES ('user_1', 'password_1', 'user_1@example.com', '12345678901');
|
||||
INSERT INTO users (username, password, email, phone)
|
||||
VALUES ('user_2', 'password_2', 'user_2@example.com', '12345678902');
|
||||
INSERT INTO users (username, password, email, phone)
|
||||
VALUES ('user_3', 'password_3', 'user_3@example.com', '12345678903');
|
||||
INSERT INTO users (username, password, email, phone)
|
||||
VALUES ('user_4', 'password_4', 'user_4@example.com', '12345678904');
|
||||
INSERT INTO users (username, password, email, phone)
|
||||
VALUES ('user_5', 'password_5', 'user_5@example.com', '12345678905');
|
||||
INSERT INTO users (username, password, email, phone)
|
||||
VALUES ('user_6', 'password_6', 'user_6@example.com', '12345678906');
|
||||
INSERT INTO users (username, password, email, phone)
|
||||
VALUES ('user_7', 'password_7', 'user_7@example.com', '12345678907');
|
||||
INSERT INTO users (username, password, email, phone)
|
||||
VALUES ('user_8', 'password_8', 'user_8@example.com', '12345678908');
|
||||
INSERT INTO users (username, password, email, phone)
|
||||
VALUES ('user_9', 'password_9', 'user_9@example.com', '12345678909');
|
||||
INSERT INTO users (username, password, email, phone)
|
||||
VALUES ('user_10', 'password_10', 'user_10@example.com', '12345678900');
|
||||
INSERT INTO users (username, password, email, phone)
|
||||
VALUES ('user_11', 'password_11', 'user_11@example.com', '12345678901');
|
||||
INSERT INTO users (username, password, email, phone)
|
||||
VALUES ('user_12', 'password_12', 'user_12@example.com', '12345678902');
|
||||
INSERT INTO users (username, password, email, phone)
|
||||
VALUES ('user_13', 'password_13', 'user_13@example.com', '12345678903');
|
||||
INSERT INTO users (username, password, email, phone)
|
||||
VALUES ('user_14', 'password_14', 'user_14@example.com', '12345678904');
|
||||
INSERT INTO users (username, password, email, phone)
|
||||
VALUES ('user_15', 'password_15', 'user_15@example.com', '12345678905');
|
||||
INSERT INTO users (username, password, email, phone)
|
||||
VALUES ('user_16', 'password_16', 'user_16@example.com', '12345678906');
|
||||
INSERT INTO users (username, password, email, phone)
|
||||
VALUES ('user_17', 'password_17', 'user_17@example.com', '12345678907');
|
||||
INSERT INTO users (username, password, email, phone)
|
||||
VALUES ('user_18', 'password_18', 'user_18@example.com', '12345678908');
|
||||
INSERT INTO users (username, password, email, phone)
|
||||
VALUES ('user_19', 'password_19', 'user_19@example.com', '12345678909');
|
||||
INSERT INTO users (username, password, email, phone)
|
||||
VALUES ('user_20', 'password_20', 'user_20@example.com', '12345678900');
|
@ -20,6 +20,7 @@ else:
|
||||
PositiveInt,
|
||||
PrivateAttr,
|
||||
ValidationError,
|
||||
WithJsonSchema,
|
||||
field_validator,
|
||||
model_validator,
|
||||
root_validator,
|
||||
|
@ -257,13 +257,6 @@ class BaseChat(ABC):
|
||||
def stream_call_reinforce_fn(self, text):
|
||||
return text
|
||||
|
||||
async def check_iterator_end(iterator):
|
||||
try:
|
||||
await asyncio.anext(iterator)
|
||||
return False # 迭代器还有下一个元素
|
||||
except StopAsyncIteration:
|
||||
return True # 迭代器已经执行结束
|
||||
|
||||
def _get_span_metadata(self, payload: Dict) -> Dict:
|
||||
metadata = {k: v for k, v in payload.items()}
|
||||
del metadata["prompt"]
|
||||
|
@ -18,6 +18,7 @@ from .operators.base import BaseOperator, WorkflowRunner
|
||||
from .operators.common_operator import (
|
||||
BranchFunc,
|
||||
BranchOperator,
|
||||
BranchTaskType,
|
||||
InputOperator,
|
||||
JoinOperator,
|
||||
MapOperator,
|
||||
@ -80,6 +81,7 @@ __all__ = [
|
||||
"BranchOperator",
|
||||
"InputOperator",
|
||||
"BranchFunc",
|
||||
"BranchTaskType",
|
||||
"WorkflowRunner",
|
||||
"TaskState",
|
||||
"is_empty_data",
|
||||
|
@ -3,6 +3,7 @@
|
||||
DAGLoader will load DAGs from dag_dirs or other sources.
|
||||
Now only support load DAGs from local files.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
@ -98,7 +99,7 @@ def _load_modules_from_file(
|
||||
return parse(mod_name, filepath)
|
||||
|
||||
|
||||
def _process_modules(mods) -> List[DAG]:
|
||||
def _process_modules(mods, show_log: bool = True) -> List[DAG]:
|
||||
top_level_dags = (
|
||||
(o, m) for m in mods for o in m.__dict__.values() if isinstance(o, DAG)
|
||||
)
|
||||
@ -106,7 +107,10 @@ def _process_modules(mods) -> List[DAG]:
|
||||
for dag, mod in top_level_dags:
|
||||
try:
|
||||
# TODO validate dag params
|
||||
logger.info(f"Found dag {dag} from mod {mod} and model file {mod.__file__}")
|
||||
if show_log:
|
||||
logger.info(
|
||||
f"Found dag {dag} from mod {mod} and model file {mod.__file__}"
|
||||
)
|
||||
found_dags.append(dag)
|
||||
except Exception:
|
||||
msg = traceback.format_exc()
|
||||
|
@ -6,9 +6,13 @@ from contextlib import suppress
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast
|
||||
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from dbgpt._private.pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
WithJsonSchema,
|
||||
field_validator,
|
||||
model_to_dict,
|
||||
model_validator,
|
||||
@ -255,9 +259,27 @@ class FlowCategory(str, Enum):
|
||||
raise ValueError(f"Invalid flow category value: {value}")
|
||||
|
||||
|
||||
_DAGModel = Annotated[
|
||||
DAG,
|
||||
WithJsonSchema(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"task_name": {"type": "string", "description": "Dummy task name"}
|
||||
},
|
||||
"description": "DAG model, not used in the serialization.",
|
||||
}
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class FlowPanel(BaseModel):
|
||||
"""Flow panel."""
|
||||
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True, json_encoders={DAG: lambda v: None}
|
||||
)
|
||||
|
||||
uid: str = Field(
|
||||
default_factory=lambda: str(uuid.uuid4()),
|
||||
description="Flow panel uid",
|
||||
@ -277,7 +299,8 @@ class FlowPanel(BaseModel):
|
||||
description="Flow category",
|
||||
examples=[FlowCategory.COMMON, FlowCategory.CHAT_AGENT],
|
||||
)
|
||||
flow_data: FlowData = Field(..., description="Flow data")
|
||||
flow_data: Optional[FlowData] = Field(None, description="Flow data")
|
||||
flow_dag: Optional[_DAGModel] = Field(None, description="Flow DAG", exclude=True)
|
||||
description: Optional[str] = Field(
|
||||
None,
|
||||
description="Flow panel description",
|
||||
@ -305,6 +328,11 @@ class FlowPanel(BaseModel):
|
||||
description="Version of the flow panel",
|
||||
examples=["0.1.0", "0.2.0"],
|
||||
)
|
||||
define_type: Optional[str] = Field(
|
||||
"json",
|
||||
description="Define type of the flow panel",
|
||||
examples=["json", "python"],
|
||||
)
|
||||
editable: bool = Field(
|
||||
True,
|
||||
description="Whether the flow panel is editable",
|
||||
@ -344,7 +372,7 @@ class FlowPanel(BaseModel):
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dict."""
|
||||
return model_to_dict(self)
|
||||
return model_to_dict(self, exclude={"flow_dag"})
|
||||
|
||||
|
||||
class FlowFactory:
|
||||
@ -356,7 +384,9 @@ class FlowFactory:
|
||||
|
||||
def build(self, flow_panel: FlowPanel) -> DAG:
|
||||
"""Build the flow."""
|
||||
flow_data = flow_panel.flow_data
|
||||
if not flow_panel.flow_data:
|
||||
raise ValueError("Flow data is required.")
|
||||
flow_data = cast(FlowData, flow_panel.flow_data)
|
||||
key_to_operator_nodes: Dict[str, FlowNodeData] = {}
|
||||
key_to_resource_nodes: Dict[str, FlowNodeData] = {}
|
||||
key_to_resource: Dict[str, ResourceMetadata] = {}
|
||||
@ -610,7 +640,10 @@ class FlowFactory:
|
||||
"""
|
||||
from dbgpt.util.module_utils import import_from_string
|
||||
|
||||
flow_data = flow_panel.flow_data
|
||||
if not flow_panel.flow_data:
|
||||
return
|
||||
|
||||
flow_data = cast(FlowData, flow_panel.flow_data)
|
||||
for node in flow_data.nodes:
|
||||
if node.data.is_operator:
|
||||
node_data = cast(ViewMetadata, node.data)
|
||||
@ -709,6 +742,8 @@ def fill_flow_panel(flow_panel: FlowPanel):
|
||||
Args:
|
||||
flow_panel (FlowPanel): The flow panel to fill.
|
||||
"""
|
||||
if not flow_panel.flow_data:
|
||||
return
|
||||
for node in flow_panel.flow_data.nodes:
|
||||
try:
|
||||
parameters_map = {}
|
||||
|
@ -1,4 +1,5 @@
|
||||
"""Base classes for operators that can be executed within a workflow."""
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
from abc import ABC, ABCMeta, abstractmethod
|
||||
@ -265,7 +266,16 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
|
||||
out_ctx = await self._runner.execute_workflow(
|
||||
self, call_data, streaming_call=True, exist_dag_ctx=dag_ctx
|
||||
)
|
||||
return out_ctx.current_task_context.task_output.output_stream
|
||||
|
||||
task_output = out_ctx.current_task_context.task_output
|
||||
if task_output.is_stream:
|
||||
return out_ctx.current_task_context.task_output.output_stream
|
||||
else:
|
||||
|
||||
async def _gen():
|
||||
yield task_output.output
|
||||
|
||||
return _gen()
|
||||
|
||||
def _blocking_call_stream(
|
||||
self,
|
||||
|
@ -1,4 +1,5 @@
|
||||
"""Common operators of AWEL."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any, Awaitable, Callable, Dict, Generic, List, Optional, Union
|
||||
@ -171,6 +172,8 @@ class MapOperator(BaseOperator, Generic[IN, OUT]):
|
||||
|
||||
|
||||
BranchFunc = Union[Callable[[IN], bool], Callable[[IN], Awaitable[bool]]]
|
||||
# Function that return the task name
|
||||
BranchTaskType = Union[str, Callable[[IN], str], Callable[[IN], Awaitable[str]]]
|
||||
|
||||
|
||||
class BranchOperator(BaseOperator, Generic[IN, OUT]):
|
||||
@ -187,7 +190,7 @@ class BranchOperator(BaseOperator, Generic[IN, OUT]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
branches: Optional[Dict[BranchFunc[IN], Union[BaseOperator, str]]] = None,
|
||||
branches: Optional[Dict[BranchFunc[IN], BranchTaskType]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Create a BranchDAGNode with a branching function.
|
||||
@ -208,6 +211,10 @@ class BranchOperator(BaseOperator, Generic[IN, OUT]):
|
||||
if not value.node_name:
|
||||
raise ValueError("branch node name must be set")
|
||||
branches[branch_function] = value.node_name
|
||||
elif callable(value):
|
||||
raise ValueError(
|
||||
"BranchTaskType must be str or BaseOperator on init"
|
||||
)
|
||||
self._branches = branches
|
||||
|
||||
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
|
||||
@ -234,14 +241,31 @@ class BranchOperator(BaseOperator, Generic[IN, OUT]):
|
||||
branches = await self.branches()
|
||||
|
||||
branch_func_tasks = []
|
||||
branch_nodes: List[Union[BaseOperator, str]] = []
|
||||
branch_name_tasks = []
|
||||
# branch_nodes: List[Union[BaseOperator, str]] = []
|
||||
for func, node_name in branches.items():
|
||||
branch_nodes.append(node_name)
|
||||
# branch_nodes.append(node_name)
|
||||
branch_func_tasks.append(
|
||||
curr_task_ctx.task_input.predicate_map(func, failed_value=None)
|
||||
)
|
||||
if callable(node_name):
|
||||
|
||||
async def map_node_name(func) -> str:
|
||||
input_context = await curr_task_ctx.task_input.map(func)
|
||||
task_name = input_context.parent_outputs[0].task_output.output
|
||||
return task_name
|
||||
|
||||
branch_name_tasks.append(map_node_name(node_name))
|
||||
|
||||
else:
|
||||
|
||||
async def _tmp_map_node_name(task_name: str) -> str:
|
||||
return task_name
|
||||
|
||||
branch_name_tasks.append(_tmp_map_node_name(node_name))
|
||||
|
||||
branch_input_ctxs: List[InputContext] = await asyncio.gather(*branch_func_tasks)
|
||||
branch_nodes: List[str] = await asyncio.gather(*branch_name_tasks)
|
||||
parent_output = task_input.parent_outputs[0].task_output
|
||||
curr_task_ctx.set_task_output(parent_output)
|
||||
skip_node_names = []
|
||||
@ -258,7 +282,7 @@ class BranchOperator(BaseOperator, Generic[IN, OUT]):
|
||||
curr_task_ctx.update_metadata("skip_node_names", skip_node_names)
|
||||
return parent_output
|
||||
|
||||
async def branches(self) -> Dict[BranchFunc[IN], Union[BaseOperator, str]]:
|
||||
async def branches(self) -> Dict[BranchFunc[IN], BranchTaskType]:
|
||||
"""Return branch logic based on input data."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
@ -298,16 +298,24 @@ class ModelMessage(BaseModel):
|
||||
return str_msg
|
||||
|
||||
@staticmethod
|
||||
def messages_to_string(messages: List["ModelMessage"]) -> str:
|
||||
def messages_to_string(
|
||||
messages: List["ModelMessage"],
|
||||
human_prefix: str = "Human",
|
||||
ai_prefix: str = "AI",
|
||||
system_prefix: str = "System",
|
||||
) -> str:
|
||||
"""Convert messages to str.
|
||||
|
||||
Args:
|
||||
messages (List[ModelMessage]): The messages
|
||||
human_prefix (str): The human prefix
|
||||
ai_prefix (str): The ai prefix
|
||||
system_prefix (str): The system prefix
|
||||
|
||||
Returns:
|
||||
str: The str messages
|
||||
"""
|
||||
return _messages_to_str(messages)
|
||||
return _messages_to_str(messages, human_prefix, ai_prefix, system_prefix)
|
||||
|
||||
|
||||
_SingleRoundMessage = List[BaseMessage]
|
||||
@ -1211,9 +1219,11 @@ def _append_view_messages(messages: List[BaseMessage]) -> List[BaseMessage]:
|
||||
content=ai_message.content,
|
||||
index=ai_message.index,
|
||||
round_index=ai_message.round_index,
|
||||
additional_kwargs=ai_message.additional_kwargs.copy()
|
||||
if ai_message.additional_kwargs
|
||||
else {},
|
||||
additional_kwargs=(
|
||||
ai_message.additional_kwargs.copy()
|
||||
if ai_message.additional_kwargs
|
||||
else {}
|
||||
),
|
||||
)
|
||||
current_round.append(view_message)
|
||||
return sum(messages_by_round, [])
|
||||
|
@ -6,7 +6,7 @@ import dataclasses
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from string import Formatter
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Type, TypeVar, Union
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, ConfigDict, model_validator
|
||||
from dbgpt.core.interface.message import BaseMessage, HumanMessage, SystemMessage
|
||||
@ -19,6 +19,8 @@ from dbgpt.core.interface.storage import (
|
||||
)
|
||||
from dbgpt.util.formatting import formatter, no_strict_formatter
|
||||
|
||||
T = TypeVar("T", bound="BasePromptTemplate")
|
||||
|
||||
|
||||
def _jinja2_formatter(template: str, **kwargs: Any) -> str:
|
||||
"""Format a template using jinja2."""
|
||||
@ -34,9 +36,9 @@ def _jinja2_formatter(template: str, **kwargs: Any) -> str:
|
||||
|
||||
|
||||
_DEFAULT_FORMATTER_MAPPING: Dict[str, Callable] = {
|
||||
"f-string": lambda is_strict: formatter.format
|
||||
if is_strict
|
||||
else no_strict_formatter.format,
|
||||
"f-string": lambda is_strict: (
|
||||
formatter.format if is_strict else no_strict_formatter.format
|
||||
),
|
||||
"jinja2": lambda is_strict: _jinja2_formatter,
|
||||
}
|
||||
|
||||
@ -88,8 +90,8 @@ class PromptTemplate(BasePromptTemplate):
|
||||
|
||||
@classmethod
|
||||
def from_template(
|
||||
cls, template: str, template_format: str = "f-string", **kwargs: Any
|
||||
) -> BasePromptTemplate:
|
||||
cls: Type[T], template: str, template_format: str = "f-string", **kwargs: Any
|
||||
) -> T:
|
||||
"""Create a prompt template from a template string."""
|
||||
input_variables = get_template_vars(template, template_format)
|
||||
return cls(
|
||||
@ -116,14 +118,14 @@ class BaseChatPromptTemplate(BaseModel, ABC):
|
||||
|
||||
@classmethod
|
||||
def from_template(
|
||||
cls,
|
||||
cls: Type[T],
|
||||
template: str,
|
||||
template_format: str = "f-string",
|
||||
response_format: Optional[str] = None,
|
||||
response_key: str = "response",
|
||||
template_is_strict: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> BaseChatPromptTemplate:
|
||||
) -> T:
|
||||
"""Create a prompt template from a template string."""
|
||||
prompt = PromptTemplate.from_template(
|
||||
template,
|
||||
|
4
dbgpt/experimental/__init__.py
Normal file
4
dbgpt/experimental/__init__.py
Normal file
@ -0,0 +1,4 @@
|
||||
"""Experimental features for DB-GPT.
|
||||
|
||||
Warning: These features are experimental and may change or be removed in the future.
|
||||
"""
|
1
dbgpt/experimental/intent/__init__.py
Normal file
1
dbgpt/experimental/intent/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
"""Intent detection module."""
|
186
dbgpt/experimental/intent/base.py
Normal file
186
dbgpt/experimental/intent/base.py
Normal file
@ -0,0 +1,186 @@
|
||||
"""Base class for intent detection."""
|
||||
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional, Type
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, Field
|
||||
from dbgpt.core import (
|
||||
BaseOutputParser,
|
||||
LLMClient,
|
||||
ModelMessage,
|
||||
ModelRequest,
|
||||
PromptTemplate,
|
||||
)
|
||||
|
||||
_DEFAULT_PROMPT = """Please select the most matching intent from the intent definitions below based on the user's question,
|
||||
and return the complete intent information according to the requirements and output format.
|
||||
1. Strictly follow the given intent definition for output; do not create intents or slot attributes on your own. If an intent has no defined slots, the output should not include slots either.
|
||||
2. Extract slot attribute values from the user's input and historical dialogue information according to the intent definition. If the corresponding target information for the slot attribute cannot be obtained, the slot value should be empty.
|
||||
3. When extracting slot values, ensure to only obtain the effective value part. Do not include auxiliary descriptions or modifiers. Ensure that all slot attributes defined in the intent are output, regardless of whether values are obtained. If no values are found, output the slot name with an empty value.
|
||||
4. Ensure that if the user's question does not provide the content defined in the intent slots, the slot values must be empty. Do not fill slots with invalid information such as 'user did not provide'.
|
||||
5. If the information extracted from the user's question does not fully correspond to the matched intent slots, generate a new question to ask the user, prompting them to provide the missing slot data.
|
||||
|
||||
{response}
|
||||
|
||||
You can refer to the following examples:
|
||||
{example}
|
||||
|
||||
The known intent information is defined as follows:
|
||||
{intent_definitions}
|
||||
|
||||
Here are the known historical dialogue messages. If they are not relevant to the user's question, they can be ignored(Some times you can extract useful intent and slot information from the historical dialogue messages).
|
||||
{history}
|
||||
|
||||
User question: {user_input}
|
||||
""" # noqa
|
||||
|
||||
_DEFAULT_PROMPT_ZH = """从下面的意图定义中选择一个和用户问题最匹配的意图,并根据要求和输出格式返回意图完整信息。
|
||||
1. 严格根给出的意图定义输出,不要自行生成意图和槽位属性,意图没有定义槽位则输出也不应该包含槽位。
|
||||
2. 从用户输入和历史对话信息中提取意图定义中槽位属性的值,如果无法获取到槽位属性对应的目标信息,则槽位值输出空。
|
||||
3. 槽位值提取时请注意只获取有效值部分,不要填入辅助描述或定语确保意图定义的槽位属性不管是否获取到值,都要输出全部定义给出的槽位属性,没有找到值的输出槽位名和空值。
|
||||
4. 请确保如果用户问题中未提供意图槽位定义的内容,则槽位值必须为空,不要在槽位里填‘用户未提供’这类无效信息。
|
||||
5. 如果用户问题内容提取的信息和匹配到的意图槽位无法完全对应,则生成新的问题向用户提问,提示用户补充缺少的槽位数据。
|
||||
|
||||
{response}
|
||||
|
||||
可以参考下面的例子:
|
||||
{example}
|
||||
|
||||
已知的意图信息定义如下:
|
||||
{intent_definitions}
|
||||
|
||||
以下是已知的历史对话消息,如果和用户问题无关可以忽略(有时可以从历史对话消息中提取有用的意图和槽位信息)。
|
||||
{history}
|
||||
|
||||
用户问题:{user_input}
|
||||
""" # noqa
|
||||
|
||||
|
||||
class IntentDetectionResponse(BaseModel):
|
||||
"""Response schema for intent detection."""
|
||||
|
||||
intent: str = Field(
|
||||
...,
|
||||
description="The intent of user question.",
|
||||
)
|
||||
thought: str = Field(
|
||||
...,
|
||||
description="Logic and rationale for selecting the current application.",
|
||||
)
|
||||
task_name: str = Field(
|
||||
...,
|
||||
description="The task name of the intent.",
|
||||
)
|
||||
slots: Optional[dict] = Field(
|
||||
None,
|
||||
description="The slots of user question.",
|
||||
)
|
||||
user_input: str = Field(
|
||||
...,
|
||||
description="Instructions generated based on intent and slot.",
|
||||
)
|
||||
ask_user: Optional[str] = Field(
|
||||
None,
|
||||
description="Questions to users.",
|
||||
)
|
||||
|
||||
def has_empty_slot(self):
|
||||
"""Check if the response has empty slot."""
|
||||
if self.slots:
|
||||
for key, value in self.slots.items():
|
||||
if not value or len(value) <= 0:
|
||||
return True
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def to_response_format(cls) -> str:
|
||||
"""Get the response format."""
|
||||
schema_dict = {
|
||||
"intent": "[Intent placeholder]",
|
||||
"thought": "Your reasoning idea here.",
|
||||
"task_name": "[Task name of the intent]",
|
||||
"slots": {
|
||||
"Slot attribute 1 in the intention definition": "[Slot value 1]",
|
||||
"Slot attribute 2 in the intention definition": "[Slot value 2]",
|
||||
},
|
||||
"ask_user": "If you want the user to supplement the slot data, the problem"
|
||||
" is raised to the user, please use the same language as the user.",
|
||||
"user_input": "Complete instructions generated according to the intention "
|
||||
"and slot, please use the same language as the user.",
|
||||
}
|
||||
# How to integration the streaming json
|
||||
schema_str = json.dumps(schema_dict, indent=2, ensure_ascii=False)
|
||||
response_format = (
|
||||
f"Please output in the following JSON format: \n{schema_str}"
|
||||
f"\nMake sure the response is correct json and can be parsed by Python "
|
||||
f"json.loads."
|
||||
)
|
||||
return response_format
|
||||
|
||||
|
||||
class BaseIntentDetection(ABC):
|
||||
"""Base class for intent detection."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
intent_definitions: str,
|
||||
prompt_template: Optional[str] = None,
|
||||
response_format: Optional[str] = None,
|
||||
examples: Optional[str] = None,
|
||||
):
|
||||
"""Create a new intent detection instance."""
|
||||
self._intent_definitions = intent_definitions
|
||||
self._prompt_template = prompt_template
|
||||
self._response_format = response_format
|
||||
self._examples = examples
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def llm_client(self) -> LLMClient:
|
||||
"""Get the LLM client."""
|
||||
|
||||
@property
|
||||
def response_schema(self) -> Type[IntentDetectionResponse]:
|
||||
"""Return the response schema."""
|
||||
return IntentDetectionResponse
|
||||
|
||||
async def detect_intent(
|
||||
self,
|
||||
messages: List[ModelMessage],
|
||||
model: Optional[str] = None,
|
||||
language: str = "en",
|
||||
) -> IntentDetectionResponse:
|
||||
"""Detect intent from messages."""
|
||||
default_prompt = _DEFAULT_PROMPT if language == "en" else _DEFAULT_PROMPT_ZH
|
||||
|
||||
models = await self.llm_client.models()
|
||||
if not models:
|
||||
raise Exception("No models available.")
|
||||
model = model or models[0].model
|
||||
history_messages = ModelMessage.messages_to_string(
|
||||
messages[:-1], human_prefix="user", ai_prefix="assistant"
|
||||
)
|
||||
|
||||
prompt_template = self._prompt_template or default_prompt
|
||||
|
||||
template: PromptTemplate = PromptTemplate.from_template(prompt_template)
|
||||
response_schema = self.response_schema
|
||||
response_format = self._response_format or response_schema.to_response_format()
|
||||
formatted_message = template.format(
|
||||
response=response_format,
|
||||
example=self._examples,
|
||||
intent_definitions=self._intent_definitions,
|
||||
history=history_messages,
|
||||
user_input=messages[-1].content,
|
||||
)
|
||||
model_messages = ModelMessage.build_human_message(formatted_message)
|
||||
model_request = ModelRequest.build_request(model, messages=[model_messages])
|
||||
model_output = await self.llm_client.generate(model_request)
|
||||
output_parser = BaseOutputParser()
|
||||
str_out = output_parser.parse_model_nostream_resp(
|
||||
model_output, "#########################"
|
||||
)
|
||||
json_out = output_parser.parse_prompt_response(str_out)
|
||||
dict_out = json.loads(json_out)
|
||||
return response_schema.model_validate(dict_out)
|
92
dbgpt/experimental/intent/operators.py
Normal file
92
dbgpt/experimental/intent/operators.py
Normal file
@ -0,0 +1,92 @@
|
||||
"""Operators for intent detection."""
|
||||
|
||||
from typing import Dict, List, Optional, cast
|
||||
|
||||
from dbgpt.core import ModelMessage, ModelRequest, ModelRequestContext
|
||||
from dbgpt.core.awel import BranchFunc, BranchOperator, BranchTaskType, MapOperator
|
||||
from dbgpt.model.operators.llm_operator import MixinLLMOperator
|
||||
|
||||
from .base import BaseIntentDetection, IntentDetectionResponse
|
||||
|
||||
|
||||
class IntentDetectionOperator(
|
||||
MixinLLMOperator, BaseIntentDetection, MapOperator[ModelRequest, ModelRequest]
|
||||
):
|
||||
"""The intent detection operator."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
intent_definitions: str,
|
||||
prompt_template: Optional[str] = None,
|
||||
response_format: Optional[str] = None,
|
||||
examples: Optional[str] = None,
|
||||
**kwargs
|
||||
):
|
||||
"""Create the intent detection operator."""
|
||||
MixinLLMOperator.__init__(self)
|
||||
MapOperator.__init__(self, **kwargs)
|
||||
BaseIntentDetection.__init__(
|
||||
self,
|
||||
intent_definitions=intent_definitions,
|
||||
prompt_template=prompt_template,
|
||||
response_format=response_format,
|
||||
examples=examples,
|
||||
)
|
||||
|
||||
async def map(self, input_value: ModelRequest) -> ModelRequest:
|
||||
"""Detect the intent.
|
||||
|
||||
Merge the intent detection result into the context.
|
||||
"""
|
||||
language = "en"
|
||||
if self.system_app:
|
||||
language = self.system_app.config.get_current_lang()
|
||||
messages = self.parse_messages(input_value)
|
||||
ic = await self.detect_intent(
|
||||
messages,
|
||||
input_value.model,
|
||||
language=language,
|
||||
)
|
||||
if not input_value.context:
|
||||
input_value.context = ModelRequestContext()
|
||||
if not input_value.context.extra:
|
||||
input_value.context.extra = {}
|
||||
input_value.context.extra["intent_detection"] = ic
|
||||
return input_value
|
||||
|
||||
def parse_messages(self, request: ModelRequest) -> List[ModelMessage]:
|
||||
"""Parse the messages from the request."""
|
||||
return request.get_messages()
|
||||
|
||||
|
||||
class IntentDetectionBranchOperator(BranchOperator[ModelRequest, ModelRequest]):
|
||||
"""The intent detection branch operator."""
|
||||
|
||||
def __init__(self, end_task_name: str, **kwargs):
|
||||
"""Create the intent detection branch operator."""
|
||||
super().__init__(**kwargs)
|
||||
self._end_task_name = end_task_name
|
||||
|
||||
async def branches(
|
||||
self,
|
||||
) -> Dict[BranchFunc[ModelRequest], BranchTaskType]:
|
||||
"""Branch the intent detection result to different tasks."""
|
||||
download_task_names = set(task.node_name for task in self.downstream) # noqa
|
||||
branch_func_map = {}
|
||||
for task_name in download_task_names:
|
||||
|
||||
def check(r: ModelRequest, outer_task_name=task_name):
|
||||
if not r.context or not r.context.extra:
|
||||
return False
|
||||
ic_result = r.context.extra.get("intent_detection")
|
||||
if not ic_result:
|
||||
return False
|
||||
ic: IntentDetectionResponse = cast(IntentDetectionResponse, ic_result)
|
||||
if ic.has_empty_slot():
|
||||
return self._end_task_name == outer_task_name
|
||||
else:
|
||||
return outer_task_name == ic.task_name
|
||||
|
||||
branch_func_map[check] = task_name
|
||||
|
||||
return branch_func_map # type: ignore
|
@ -1,6 +1,7 @@
|
||||
"""This is an auto-generated model file
|
||||
You can define your own models and DAOs here
|
||||
"""
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, Union
|
||||
@ -33,6 +34,12 @@ class ServeEntity(Model):
|
||||
source = Column(String(64), nullable=True, comment="Flow source")
|
||||
source_url = Column(String(512), nullable=True, comment="Flow source url")
|
||||
version = Column(String(32), nullable=True, comment="Flow version")
|
||||
define_type = Column(
|
||||
String(32),
|
||||
default="json",
|
||||
nullable=True,
|
||||
comment="Flow define type(json or python)",
|
||||
)
|
||||
editable = Column(
|
||||
Integer, nullable=True, comment="Editable, 0: editable, 1: not editable"
|
||||
)
|
||||
@ -103,6 +110,7 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
|
||||
"source": request_dict.get("source"),
|
||||
"source_url": request_dict.get("source_url"),
|
||||
"version": request_dict.get("version"),
|
||||
"define_type": request_dict.get("define_type"),
|
||||
"editable": ServeEntity.parse_editable(request_dict.get("editable")),
|
||||
"description": request_dict.get("description"),
|
||||
"user_name": request_dict.get("user_name"),
|
||||
@ -133,6 +141,7 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
|
||||
source=entity.source,
|
||||
source_url=entity.source_url,
|
||||
version=entity.version,
|
||||
define_type=entity.define_type,
|
||||
editable=ServeEntity.to_bool_editable(entity.editable),
|
||||
description=entity.description,
|
||||
user_name=entity.user_name,
|
||||
@ -165,6 +174,7 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
|
||||
source_url=entity.source_url,
|
||||
version=entity.version,
|
||||
editable=ServeEntity.to_bool_editable(entity.editable),
|
||||
define_type=entity.define_type,
|
||||
user_name=entity.user_name,
|
||||
sys_code=entity.sys_code,
|
||||
gmt_created=gmt_created_str,
|
||||
@ -203,6 +213,8 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
|
||||
if update_request.version:
|
||||
entry.version = update_request.version
|
||||
entry.editable = ServeEntity.parse_editable(update_request.editable)
|
||||
if update_request.define_type:
|
||||
entry.define_type = update_request.define_type
|
||||
if update_request.user_name:
|
||||
entry.user_name = update_request.user_name
|
||||
if update_request.sys_code:
|
||||
|
@ -138,7 +138,10 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
|
||||
"""
|
||||
try:
|
||||
# Build DAG from request
|
||||
dag = self._flow_factory.build(request)
|
||||
if request.define_type == "json":
|
||||
dag = self._flow_factory.build(request)
|
||||
else:
|
||||
dag = request.flow_dag
|
||||
request.dag_id = dag.dag_id
|
||||
# Save DAG to storage
|
||||
request.flow_category = self._parse_flow_category(dag)
|
||||
@ -149,7 +152,9 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
|
||||
request.dag_id = ""
|
||||
return self.dao.create(request)
|
||||
else:
|
||||
raise e
|
||||
raise ValueError(
|
||||
f"Create DAG {request.name} error, define_type: {request.define_type}, error: {str(e)}"
|
||||
) from e
|
||||
res = self.dao.create(request)
|
||||
|
||||
state = request.state
|
||||
@ -193,6 +198,8 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
|
||||
entities = self.dao.get_list({})
|
||||
for entity in entities:
|
||||
try:
|
||||
if entity.define_type != "json":
|
||||
continue
|
||||
dag = self._flow_factory.build(entity)
|
||||
if entity.state in [State.DEPLOYED, State.RUNNING] or (
|
||||
entity.version == "0.1.0" and entity.state == State.INITIALIZING
|
||||
@ -213,7 +220,8 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
|
||||
flows = self.dbgpts_loader.get_flows()
|
||||
for flow in flows:
|
||||
try:
|
||||
self._flow_factory.pre_load_requirements(flow)
|
||||
if flow.define_type == "json":
|
||||
self._flow_factory.pre_load_requirements(flow)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Pre load requirements for DAG({flow.name}) from "
|
||||
@ -225,6 +233,8 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
|
||||
flows = self.dbgpts_loader.get_flows()
|
||||
for flow in flows:
|
||||
try:
|
||||
if flow.define_type == "python" and flow.flow_dag is None:
|
||||
continue
|
||||
# Set state to DEPLOYED
|
||||
flow.state = State.DEPLOYED
|
||||
exist_inst = self.get({"name": flow.name})
|
||||
@ -260,7 +270,10 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
|
||||
new_state = request.state
|
||||
try:
|
||||
# Try to build the dag from the request
|
||||
dag = self._flow_factory.build(request)
|
||||
if request.define_type == "json":
|
||||
dag = self._flow_factory.build(request)
|
||||
else:
|
||||
dag = request.flow_dag
|
||||
request.flow_category = self._parse_flow_category(dag)
|
||||
except Exception as e:
|
||||
if save_failed_flow:
|
||||
@ -295,6 +308,7 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Flow detail {request.uid} not found"
|
||||
)
|
||||
update_obj.flow_dag = request.flow_dag
|
||||
return self.create_and_save_dag(update_obj)
|
||||
except Exception as e:
|
||||
if old_data and old_data.state == State.RUNNING:
|
||||
|
@ -10,6 +10,7 @@ import tomlkit
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
from dbgpt.component import BaseComponent, SystemApp
|
||||
from dbgpt.core.awel import DAG
|
||||
from dbgpt.core.awel.flow.flow_factory import FlowPanel
|
||||
from dbgpt.util.dbgpts.base import (
|
||||
DBGPTS_METADATA_FILE,
|
||||
@ -77,7 +78,7 @@ class BasePackage(BaseModel):
|
||||
values: Dict[str, Any],
|
||||
expected_cls: Type[T],
|
||||
predicates: Optional[List[Callable[..., bool]]] = None,
|
||||
) -> Tuple[List[Type[T]], List[Any]]:
|
||||
) -> Tuple[List[Type[T]], List[Any], List[Any]]:
|
||||
import importlib.resources as pkg_resources
|
||||
|
||||
from dbgpt.core.awel.dag.loader import _load_modules_from_file
|
||||
@ -101,7 +102,7 @@ class BasePackage(BaseModel):
|
||||
for c in list_cls:
|
||||
if issubclass(c, expected_cls):
|
||||
module_cls.append(c)
|
||||
return module_cls, all_predicate_results
|
||||
return module_cls, all_predicate_results, mods
|
||||
|
||||
|
||||
class FlowPackage(BasePackage):
|
||||
@ -113,6 +114,24 @@ class FlowPackage(BasePackage):
|
||||
) -> "FlowPackage":
|
||||
if values["definition_type"] == "json":
|
||||
return FlowJsonPackage.build_from(values, ext_dict)
|
||||
return FlowPythonPackage.build_from(values, ext_dict)
|
||||
|
||||
|
||||
class FlowPythonPackage(FlowPackage):
|
||||
dag: DAG = Field(..., description="The DAG of the package")
|
||||
|
||||
@classmethod
|
||||
def build_from(cls, values: Dict[str, Any], ext_dict: Dict[str, Any]):
|
||||
from dbgpt.core.awel.dag.loader import _process_modules
|
||||
|
||||
_, _, mods = cls.load_module_class(values, DAG)
|
||||
|
||||
dags = _process_modules(mods, show_log=False)
|
||||
if not dags:
|
||||
raise ValueError("No DAGs found in the package")
|
||||
if len(dags) > 1:
|
||||
raise ValueError("Only support one DAG in the package")
|
||||
values["dag"] = dags[0]
|
||||
return cls(**values)
|
||||
|
||||
|
||||
@ -144,7 +163,7 @@ class OperatorPackage(BasePackage):
|
||||
def build_from(cls, values: Dict[str, Any], ext_dict: Dict[str, Any]):
|
||||
from dbgpt.core.awel import BaseOperator
|
||||
|
||||
values["operators"], _ = cls.load_module_class(values, BaseOperator)
|
||||
values["operators"], _, _ = cls.load_module_class(values, BaseOperator)
|
||||
return cls(**values)
|
||||
|
||||
|
||||
@ -159,7 +178,7 @@ class AgentPackage(BasePackage):
|
||||
def build_from(cls, values: Dict[str, Any], ext_dict: Dict[str, Any]):
|
||||
from dbgpt.agent import ConversableAgent
|
||||
|
||||
values["agents"], _ = cls.load_module_class(values, ConversableAgent)
|
||||
values["agents"], _, _ = cls.load_module_class(values, ConversableAgent)
|
||||
return cls(**values)
|
||||
|
||||
|
||||
@ -190,7 +209,7 @@ class ResourcePackage(BasePackage):
|
||||
else:
|
||||
return False
|
||||
|
||||
_, predicted_cls = cls.load_module_class(values, Resource, [_predicate])
|
||||
_, predicted_cls, _ = cls.load_module_class(values, Resource, [_predicate])
|
||||
resource_instances = []
|
||||
resources = []
|
||||
for o in predicted_cls:
|
||||
@ -353,7 +372,7 @@ class DBGPTsLoader(BaseComponent):
|
||||
for package in self._packages.values():
|
||||
if package.package_type != "flow":
|
||||
continue
|
||||
package = cast(FlowJsonPackage, package)
|
||||
package = cast(FlowPackage, package)
|
||||
dict_value = {
|
||||
"name": package.name,
|
||||
"label": package.label,
|
||||
@ -361,8 +380,24 @@ class DBGPTsLoader(BaseComponent):
|
||||
"editable": False,
|
||||
"description": package.description,
|
||||
"source": package.repo,
|
||||
"flow_data": package.read_definition_json(),
|
||||
"define_type": "json",
|
||||
}
|
||||
if isinstance(package, FlowJsonPackage):
|
||||
dict_value["flow_data"] = package.read_definition_json()
|
||||
elif isinstance(package, FlowPythonPackage):
|
||||
dict_value["flow_data"] = {
|
||||
"nodes": [],
|
||||
"edges": [],
|
||||
"viewport": {
|
||||
"x": 213,
|
||||
"y": 269,
|
||||
"zoom": 0,
|
||||
},
|
||||
}
|
||||
dict_value["flow_dag"] = package.dag
|
||||
dict_value["define_type"] = "python"
|
||||
else:
|
||||
raise ValueError(f"Unsupported package type: {package}")
|
||||
panels.append(FlowPanel(**dict_value))
|
||||
return panels
|
||||
|
||||
|
@ -97,9 +97,7 @@ def _create_flow_template(
|
||||
if definition_type == "json":
|
||||
_write_flow_define_json_file(working_directory, name, mod_name)
|
||||
else:
|
||||
raise click.ClickException(
|
||||
f"Unsupported definition type: {definition_type} for dbgpts type: {dbgpts_type}"
|
||||
)
|
||||
_write_flow_define_python_file(working_directory, name, mod_name)
|
||||
|
||||
|
||||
def _create_operator_template(
|
||||
@ -222,6 +220,16 @@ def _write_flow_define_json_file(working_directory: str, name: str, mod_name: st
|
||||
print("Please write your flow json to the file: ", def_file)
|
||||
|
||||
|
||||
def _write_flow_define_python_file(working_directory: str, name: str, mod_name: str):
|
||||
"""Write the flow define python file"""
|
||||
|
||||
init_file = Path(working_directory) / name / mod_name / "__init__.py"
|
||||
content = ""
|
||||
|
||||
with open(init_file, "w") as f:
|
||||
f.write(f'"""{name} flow package"""\n{content}')
|
||||
|
||||
|
||||
def _write_operator_init_file(working_directory: str, name: str, mod_name: str):
|
||||
"""Write the operator __init__.py file"""
|
||||
|
||||
|
6
package-lock.json
generated
6
package-lock.json
generated
@ -1,6 +0,0 @@
|
||||
{
|
||||
"name": "DB-GPT",
|
||||
"lockfileVersion": 3,
|
||||
"requires": true,
|
||||
"packages": {}
|
||||
}
|
Loading…
Reference in New Issue
Block a user