diff --git a/assets/schema/dbgpt.sql b/assets/schema/dbgpt.sql
index 86f7ed740..237091b1e 100644
--- a/assets/schema/dbgpt.sql
+++ b/assets/schema/dbgpt.sql
@@ -282,6 +282,7 @@ CREATE TABLE `dbgpt_serve_flow` (
     `source` varchar(64) DEFAULT NULL COMMENT 'Flow source',
     `source_url` varchar(512) DEFAULT NULL COMMENT 'Flow source url',
     `version` varchar(32) DEFAULT NULL COMMENT 'Flow version',
+    `define_type` varchar(32) DEFAULT NULL COMMENT 'Flow define type (json or python)',
     `label` varchar(128) DEFAULT NULL COMMENT 'Flow label',
     `editable` int DEFAULT NULL COMMENT 'Editable, 0: editable, 1: not editable',
     PRIMARY KEY (`id`),
@@ -340,6 +341,28 @@ CREATE TABLE `gpts_app_detail` (
     UNIQUE KEY `uk_gpts_app_agent_node` (`app_name`,`agent_name`,`node_id`)
 ) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;
 
+
+-- For the model cluster deployment of DB-GPT (StorageModelRegistry)
+CREATE TABLE IF NOT EXISTS `dbgpt_cluster_registry_instance` (
+  `id` int(11) NOT NULL AUTO_INCREMENT COMMENT 'Auto increment id',
+  `model_name` varchar(128) NOT NULL COMMENT 'Model name',
+  `host` varchar(128) NOT NULL COMMENT 'Host of the model',
+  `port` int(11) NOT NULL COMMENT 'Port of the model',
+  `weight` float DEFAULT 1.0 COMMENT 'Weight of the model',
+  `check_healthy` tinyint(1) DEFAULT 1 COMMENT 'Whether to check the health of the model',
+  `healthy` tinyint(1) DEFAULT 0 COMMENT 'Whether the model is healthy',
+  `enabled` tinyint(1) DEFAULT 1 COMMENT 'Whether the model is enabled',
+  `prompt_template` varchar(128) DEFAULT NULL COMMENT 'Prompt template for the model instance',
+  `last_heartbeat` datetime DEFAULT NULL COMMENT 'Last heartbeat time of the model instance',
+  `user_name` varchar(128) DEFAULT NULL COMMENT 'User name',
+  `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code',
+  `gmt_created` datetime DEFAULT CURRENT_TIMESTAMP COMMENT 'Record creation time',
+  `gmt_modified` datetime DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'Record update time',
+  PRIMARY KEY (`id`),
+  UNIQUE KEY `uk_model_instance` (`model_name`, `host`, `port`, `sys_code`)
+) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT='Cluster model instance table, for registering and managing model instances';
+
+
 CREATE DATABASE IF NOT EXISTS EXAMPLE_1;
 use EXAMPLE_1;
diff --git a/assets/schema/upgrade/v_0_5_6/upgrade_to_v0.5.6.sql b/assets/schema/upgrade/v0_5_6/upgrade_to_v0.5.6.sql
similarity index 100%
rename from assets/schema/upgrade/v_0_5_6/upgrade_to_v0.5.6.sql
rename to assets/schema/upgrade/v0_5_6/upgrade_to_v0.5.6.sql
diff --git a/assets/schema/upgrade/v_0_5_6/v0.5.5.sql b/assets/schema/upgrade/v0_5_6/v0.5.5.sql
similarity index 100%
rename from assets/schema/upgrade/v_0_5_6/v0.5.5.sql
rename to assets/schema/upgrade/v0_5_6/v0.5.5.sql
diff --git a/assets/schema/upgrade/v_0_5_7/upgrade_to_v0.5.7.sql b/assets/schema/upgrade/v0_5_7/upgrade_to_v0.5.7.sql
similarity index 100%
rename from assets/schema/upgrade/v_0_5_7/upgrade_to_v0.5.7.sql
rename to assets/schema/upgrade/v0_5_7/upgrade_to_v0.5.7.sql
diff --git a/assets/schema/upgrade/v_0_5_7/v0.5.6.sql b/assets/schema/upgrade/v0_5_7/v0.5.6.sql
similarity index 100%
rename from assets/schema/upgrade/v_0_5_7/v0.5.6.sql
rename to assets/schema/upgrade/v0_5_7/v0.5.6.sql
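Reviewer note: the new `dbgpt_cluster_registry_instance` table above is the persistence layer for the `StorageModelRegistry` introduced later in this diff. The unique key `(model_name, host, port, sys_code)` makes registration an upsert, and `last_heartbeat`/`healthy` drive the heartbeat timeout checker. A hedged sketch of wiring it up (the DSN and credentials are placeholders, not values from this PR):

```python
from dbgpt.model.cluster.registry_impl.storage import StorageModelRegistry

# Placeholder MySQL DSN; the registry stores model instances in the
# `dbgpt_cluster_registry_instance` table instead of in process memory.
registry = StorageModelRegistry.from_url(
    "mysql+pymysql://user:password@127.0.0.1:3306/dbgpt?charset=utf8mb4",
    "dbgpt",  # database name
    try_to_create_db=False,
    heartbeat_interval_secs=20,  # how often the background checker runs
    heartbeat_timeout_secs=60,  # instances silent longer than this are marked unhealthy
)
```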
diff --git a/assets/schema/upgrade/v0_5_9/upgrade_to_v0.5.9.sql b/assets/schema/upgrade/v0_5_9/upgrade_to_v0.5.9.sql
new file mode 100644
index 000000000..8ee337bf8
--- /dev/null
+++ b/assets/schema/upgrade/v0_5_9/upgrade_to_v0.5.9.sql
@@ -0,0 +1,22 @@
+USE dbgpt;
+
+-- For the model cluster deployment of DB-GPT (StorageModelRegistry)
+CREATE TABLE IF NOT EXISTS `dbgpt_cluster_registry_instance` (
+  `id` int(11) NOT NULL AUTO_INCREMENT COMMENT 'Auto increment id',
+  `model_name` varchar(128) NOT NULL COMMENT 'Model name',
+  `host` varchar(128) NOT NULL COMMENT 'Host of the model',
+  `port` int(11) NOT NULL COMMENT 'Port of the model',
+  `weight` float DEFAULT 1.0 COMMENT 'Weight of the model',
+  `check_healthy` tinyint(1) DEFAULT 1 COMMENT 'Whether to check the health of the model',
+  `healthy` tinyint(1) DEFAULT 0 COMMENT 'Whether the model is healthy',
+  `enabled` tinyint(1) DEFAULT 1 COMMENT 'Whether the model is enabled',
+  `prompt_template` varchar(128) DEFAULT NULL COMMENT 'Prompt template for the model instance',
+  `last_heartbeat` datetime DEFAULT NULL COMMENT 'Last heartbeat time of the model instance',
+  `user_name` varchar(128) DEFAULT NULL COMMENT 'User name',
+  `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code',
+  `gmt_created` datetime DEFAULT CURRENT_TIMESTAMP COMMENT 'Record creation time',
+  `gmt_modified` datetime DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'Record update time',
+  PRIMARY KEY (`id`),
+  UNIQUE KEY `uk_model_instance` (`model_name`, `host`, `port`, `sys_code`)
+) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT='Cluster model instance table, for registering and managing model instances';
+
diff --git a/assets/schema/upgrade/v0_5_9/v0.5.8.sql b/assets/schema/upgrade/v0_5_9/v0.5.8.sql
new file mode 100644
index 000000000..d810d2369
--- /dev/null
+++ b/assets/schema/upgrade/v0_5_9/v0.5.8.sql
@@ -0,0 +1,396 @@
+-- Full SQL of v0.5.8, please do not modify this file (it must be the same as the file in the release package)
+
+CREATE
+DATABASE IF NOT EXISTS dbgpt;
+use dbgpt;
+
+-- For alembic migration tool
+CREATE TABLE IF NOT EXISTS `alembic_version`
+(
+    version_num VARCHAR(32) NOT NULL,
+    CONSTRAINT alembic_version_pkc PRIMARY KEY (version_num)
+) DEFAULT CHARSET=utf8mb4 ;
+
+CREATE TABLE IF NOT EXISTS `knowledge_space`
+(
+    `id` int NOT NULL AUTO_INCREMENT COMMENT 'auto increment id',
+    `name` varchar(100) NOT NULL COMMENT 'knowledge space name',
+    `vector_type` varchar(50) NOT NULL COMMENT 'vector type',
+    `desc` varchar(500) NOT NULL COMMENT 'description',
+    `owner` varchar(100) DEFAULT NULL COMMENT 'owner',
+    `context` TEXT DEFAULT NULL COMMENT 'context argument',
+    `gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
+    `gmt_modified` TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time',
+    PRIMARY KEY (`id`),
+    KEY `idx_name` (`name`) COMMENT 'index:idx_name'
+) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT='knowledge space table';
+
+CREATE TABLE IF NOT EXISTS `knowledge_document`
+(
+    `id` int NOT NULL AUTO_INCREMENT COMMENT 'auto increment id',
+    `doc_name` varchar(100) NOT NULL COMMENT 'document path name',
+    `doc_type` varchar(50) NOT NULL COMMENT 'doc type',
+    `space` varchar(50) NOT NULL COMMENT 'knowledge space',
+    `chunk_size` int NOT NULL COMMENT 'chunk size',
+    `last_sync` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'last sync time',
+    `status` varchar(50) NOT NULL COMMENT 'status TODO,RUNNING,FAILED,FINISHED',
+    `content` LONGTEXT NOT NULL COMMENT 'knowledge embedding sync result',
+    `result` TEXT NULL COMMENT 'knowledge content',
+    `vector_ids` LONGTEXT NULL COMMENT 'vector_ids',
+    `summary` LONGTEXT NULL COMMENT 'knowledge summary',
+    `gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
+    
`gmt_modified` TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time', + PRIMARY KEY (`id`), + KEY `idx_doc_name` (`doc_name`) COMMENT 'index:idx_doc_name' +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT='knowledge document table'; + +CREATE TABLE IF NOT EXISTS `document_chunk` +( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'auto increment id', + `doc_name` varchar(100) NOT NULL COMMENT 'document path name', + `doc_type` varchar(50) NOT NULL COMMENT 'doc type', + `document_id` int NOT NULL COMMENT 'document parent id', + `content` longtext NOT NULL COMMENT 'chunk content', + `meta_info` varchar(200) NOT NULL COMMENT 'metadata info', + `gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time', + `gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time', + PRIMARY KEY (`id`), + KEY `idx_document_id` (`document_id`) COMMENT 'index:document_id' +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT='knowledge document chunk detail'; + + + +CREATE TABLE IF NOT EXISTS `connect_config` +( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', + `db_type` varchar(255) NOT NULL COMMENT 'db type', + `db_name` varchar(255) NOT NULL COMMENT 'db name', + `db_path` varchar(255) DEFAULT NULL COMMENT 'file db path', + `db_host` varchar(255) DEFAULT NULL COMMENT 'db connect host(not file db)', + `db_port` varchar(255) DEFAULT NULL COMMENT 'db cnnect port(not file db)', + `db_user` varchar(255) DEFAULT NULL COMMENT 'db user', + `db_pwd` varchar(255) DEFAULT NULL COMMENT 'db password', + `comment` text COMMENT 'db comment', + `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code', + PRIMARY KEY (`id`), + UNIQUE KEY `uk_db` (`db_name`), + KEY `idx_q_db_type` (`db_type`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT 'Connection confi'; + +CREATE TABLE IF NOT EXISTS `chat_history` +( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', + `conv_uid` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'Conversation record unique id', + `chat_mode` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'Conversation scene mode', + `summary` longtext COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'Conversation record summary', + `user_name` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'interlocutor', + `messages` text COLLATE utf8mb4_unicode_ci COMMENT 'Conversation details', + `message_ids` text COLLATE utf8mb4_unicode_ci COMMENT 'Message id list, split by comma', + `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code', + `gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time', + `gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time', + UNIQUE KEY `conv_uid` (`conv_uid`), + PRIMARY KEY (`id`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT 'Chat history'; + +CREATE TABLE IF NOT EXISTS `chat_history_message` +( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', + `conv_uid` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'Conversation record unique id', + `index` int NOT NULL COMMENT 'Message index', + `round_index` int NOT NULL COMMENT 'Round of conversation', + `message_detail` text COLLATE utf8mb4_unicode_ci COMMENT 'Message details, json format', + `gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time', + `gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP 
ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time', + UNIQUE KEY `message_uid_index` (`conv_uid`, `index`), + PRIMARY KEY (`id`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT 'Chat history message'; + +CREATE TABLE IF NOT EXISTS `chat_feed_back` +( + `id` bigint(20) NOT NULL AUTO_INCREMENT, + `conv_uid` varchar(128) DEFAULT NULL COMMENT 'Conversation ID', + `conv_index` int(4) DEFAULT NULL COMMENT 'Round of conversation', + `score` int(1) DEFAULT NULL COMMENT 'Score of user', + `ques_type` varchar(32) DEFAULT NULL COMMENT 'User question category', + `question` longtext DEFAULT NULL COMMENT 'User question', + `knowledge_space` varchar(128) DEFAULT NULL COMMENT 'Knowledge space name', + `messages` longtext DEFAULT NULL COMMENT 'The details of user feedback', + `user_name` varchar(128) DEFAULT NULL COMMENT 'User name', + `gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time', + `gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time', + PRIMARY KEY (`id`), + UNIQUE KEY `uk_conv` (`conv_uid`,`conv_index`), + KEY `idx_conv` (`conv_uid`,`conv_index`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT='User feedback table'; + + +CREATE TABLE IF NOT EXISTS `my_plugin` +( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', + `tenant` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'user tenant', + `user_code` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'user code', + `user_name` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'user name', + `name` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'plugin name', + `file_name` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'plugin package file name', + `type` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin type', + `version` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin version', + `use_count` int DEFAULT NULL COMMENT 'plugin total use count', + `succ_count` int DEFAULT NULL COMMENT 'plugin total success count', + `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code', + `gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'plugin install time', + PRIMARY KEY (`id`), + UNIQUE KEY `name` (`name`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='User plugin table'; + +CREATE TABLE IF NOT EXISTS `plugin_hub` +( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', + `name` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'plugin name', + `description` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'plugin description', + `author` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin author', + `email` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin author email', + `type` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin type', + `version` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin version', + `storage_channel` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin storage channel', + `storage_url` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin download url', + `download_param` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin download param', + `gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'plugin upload time', + `installed` int DEFAULT NULL COMMENT 'plugin already installed count', + 
PRIMARY KEY (`id`), + UNIQUE KEY `name` (`name`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='Plugin Hub table'; + + +CREATE TABLE IF NOT EXISTS `prompt_manage` +( + `id` int(11) NOT NULL AUTO_INCREMENT, + `chat_scene` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Chat scene', + `sub_chat_scene` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Sub chat scene', + `prompt_type` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt type: common or private', + `prompt_name` varchar(256) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'prompt name', + `content` longtext COLLATE utf8mb4_unicode_ci COMMENT 'Prompt content', + `input_variables` varchar(1024) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt input variables(split by comma))', + `model` varchar(128) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt model name(we can use different models for different prompt)', + `prompt_language` varchar(32) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt language(eg:en, zh-cn)', + `prompt_format` varchar(32) COLLATE utf8mb4_unicode_ci DEFAULT 'f-string' COMMENT 'Prompt format(eg: f-string, jinja2)', + `prompt_desc` varchar(512) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt description', + `user_name` varchar(128) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'User name', + `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code', + `gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time', + `gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time', + PRIMARY KEY (`id`), + UNIQUE KEY `prompt_name_uiq` (`prompt_name`, `sys_code`, `prompt_language`, `model`), + KEY `gmt_created_idx` (`gmt_created`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='Prompt management table'; + + CREATE TABLE IF NOT EXISTS `gpts_conversations` ( + `id` int(11) NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', + `conv_id` varchar(255) NOT NULL COMMENT 'The unique id of the conversation record', + `user_goal` text NOT NULL COMMENT 'User''s goals content', + `gpts_name` varchar(255) NOT NULL COMMENT 'The gpts name', + `state` varchar(255) DEFAULT NULL COMMENT 'The gpts state', + `max_auto_reply_round` int(11) NOT NULL COMMENT 'max auto reply round', + `auto_reply_count` int(11) NOT NULL COMMENT 'auto reply count', + `user_code` varchar(255) DEFAULT NULL COMMENT 'user code', + `sys_code` varchar(255) DEFAULT NULL COMMENT 'system app ', + `created_at` datetime DEFAULT NULL COMMENT 'create time', + `updated_at` datetime DEFAULT NULL COMMENT 'last update time', + `team_mode` varchar(255) NULL COMMENT 'agent team work mode', + + PRIMARY KEY (`id`), + UNIQUE KEY `uk_gpts_conversations` (`conv_id`), + KEY `idx_gpts_name` (`gpts_name`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT="gpt conversations"; + +CREATE TABLE IF NOT EXISTS `gpts_instance` ( + `id` int(11) NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', + `gpts_name` varchar(255) NOT NULL COMMENT 'Current AI assistant name', + `gpts_describe` varchar(2255) NOT NULL COMMENT 'Current AI assistant describe', + `resource_db` text COMMENT 'List of structured database names contained in the current gpts', + `resource_internet` text COMMENT 'Is it possible to retrieve information from the internet', + `resource_knowledge` text COMMENT 'List of unstructured database names contained in the current gpts', + `gpts_agents` 
varchar(1000) DEFAULT NULL COMMENT 'List of agents names contained in the current gpts', + `gpts_models` varchar(1000) DEFAULT NULL COMMENT 'List of llm model names contained in the current gpts', + `language` varchar(100) DEFAULT NULL COMMENT 'gpts language', + `user_code` varchar(255) NOT NULL COMMENT 'user code', + `sys_code` varchar(255) DEFAULT NULL COMMENT 'system app code', + `created_at` datetime DEFAULT NULL COMMENT 'create time', + `updated_at` datetime DEFAULT NULL COMMENT 'last update time', + `team_mode` varchar(255) NOT NULL COMMENT 'Team work mode', + `is_sustainable` tinyint(1) NOT NULL COMMENT 'Applications for sustainable dialogue', + PRIMARY KEY (`id`), + UNIQUE KEY `uk_gpts` (`gpts_name`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT="gpts instance"; + +CREATE TABLE `gpts_messages` ( + `id` int(11) NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', + `conv_id` varchar(255) NOT NULL COMMENT 'The unique id of the conversation record', + `sender` varchar(255) NOT NULL COMMENT 'Who speaking in the current conversation turn', + `receiver` varchar(255) NOT NULL COMMENT 'Who receive message in the current conversation turn', + `model_name` varchar(255) DEFAULT NULL COMMENT 'message generate model', + `rounds` int(11) NOT NULL COMMENT 'dialogue turns', + `content` text COMMENT 'Content of the speech', + `current_goal` text COMMENT 'The target corresponding to the current message', + `context` text COMMENT 'Current conversation context', + `review_info` text COMMENT 'Current conversation review info', + `action_report` text COMMENT 'Current conversation action report', + `role` varchar(255) DEFAULT NULL COMMENT 'The role of the current message content', + `created_at` datetime DEFAULT NULL COMMENT 'create time', + `updated_at` datetime DEFAULT NULL COMMENT 'last update time', + PRIMARY KEY (`id`), + KEY `idx_q_messages` (`conv_id`,`rounds`,`sender`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT="gpts message"; + + +CREATE TABLE `gpts_plans` ( + `id` int(11) NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', + `conv_id` varchar(255) NOT NULL COMMENT 'The unique id of the conversation record', + `sub_task_num` int(11) NOT NULL COMMENT 'Subtask number', + `sub_task_title` varchar(255) NOT NULL COMMENT 'subtask title', + `sub_task_content` text NOT NULL COMMENT 'subtask content', + `sub_task_agent` varchar(255) DEFAULT NULL COMMENT 'Available agents corresponding to subtasks', + `resource_name` varchar(255) DEFAULT NULL COMMENT 'resource name', + `rely` varchar(255) DEFAULT NULL COMMENT 'Subtask dependencies,like: 1,2,3', + `agent_model` varchar(255) DEFAULT NULL COMMENT 'LLM model used by subtask processing agents', + `retry_times` int(11) DEFAULT NULL COMMENT 'number of retries', + `max_retry_times` int(11) DEFAULT NULL COMMENT 'Maximum number of retries', + `state` varchar(255) DEFAULT NULL COMMENT 'subtask status', + `result` longtext COMMENT 'subtask result', + `created_at` datetime DEFAULT NULL COMMENT 'create time', + `updated_at` datetime DEFAULT NULL COMMENT 'last update time', + PRIMARY KEY (`id`), + UNIQUE KEY `uk_sub_task` (`conv_id`,`sub_task_num`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT="gpt plan"; + +-- dbgpt.dbgpt_serve_flow definition +CREATE TABLE `dbgpt_serve_flow` ( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'Auto increment id', + `uid` varchar(128) NOT NULL COMMENT 'Unique id', + `dag_id` varchar(128) DEFAULT NULL COMMENT 'DAG id', + `name` varchar(128) DEFAULT NULL COMMENT 'Flow name', + 
`flow_data` text COMMENT 'Flow data, JSON format', + `user_name` varchar(128) DEFAULT NULL COMMENT 'User name', + `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code', + `gmt_created` datetime DEFAULT NULL COMMENT 'Record creation time', + `gmt_modified` datetime DEFAULT NULL COMMENT 'Record update time', + `flow_category` varchar(64) DEFAULT NULL COMMENT 'Flow category', + `description` varchar(512) DEFAULT NULL COMMENT 'Flow description', + `state` varchar(32) DEFAULT NULL COMMENT 'Flow state', + `error_message` varchar(512) NULL comment 'Error message', + `source` varchar(64) DEFAULT NULL COMMENT 'Flow source', + `source_url` varchar(512) DEFAULT NULL COMMENT 'Flow source url', + `version` varchar(32) DEFAULT NULL COMMENT 'Flow version', + `define_type` varchar(32) null comment 'Flow define type(json or python)', + `label` varchar(128) DEFAULT NULL COMMENT 'Flow label', + `editable` int DEFAULT NULL COMMENT 'Editable, 0: editable, 1: not editable', + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`), + KEY `ix_dbgpt_serve_flow_sys_code` (`sys_code`), + KEY `ix_dbgpt_serve_flow_uid` (`uid`), + KEY `ix_dbgpt_serve_flow_dag_id` (`dag_id`), + KEY `ix_dbgpt_serve_flow_user_name` (`user_name`), + KEY `ix_dbgpt_serve_flow_name` (`name`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; + +-- dbgpt.gpts_app definition +CREATE TABLE `gpts_app` ( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', + `app_code` varchar(255) NOT NULL COMMENT 'Current AI assistant code', + `app_name` varchar(255) NOT NULL COMMENT 'Current AI assistant name', + `app_describe` varchar(2255) NOT NULL COMMENT 'Current AI assistant describe', + `language` varchar(100) NOT NULL COMMENT 'gpts language', + `team_mode` varchar(255) NOT NULL COMMENT 'Team work mode', + `team_context` text COMMENT 'The execution logic and team member content that teams with different working modes rely on', + `user_code` varchar(255) DEFAULT NULL COMMENT 'user code', + `sys_code` varchar(255) DEFAULT NULL COMMENT 'system app code', + `created_at` datetime DEFAULT NULL COMMENT 'create time', + `updated_at` datetime DEFAULT NULL COMMENT 'last update time', + `icon` varchar(1024) DEFAULT NULL COMMENT 'app icon, url', + PRIMARY KEY (`id`), + UNIQUE KEY `uk_gpts_app` (`app_name`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; + +CREATE TABLE `gpts_app_collection` ( + `id` int(11) NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', + `app_code` varchar(255) NOT NULL COMMENT 'Current AI assistant code', + `user_code` int(11) NOT NULL COMMENT 'user code', + `sys_code` varchar(255) NOT NULL COMMENT 'system app code', + `created_at` datetime DEFAULT NULL COMMENT 'create time', + `updated_at` datetime DEFAULT NULL COMMENT 'last update time', + PRIMARY KEY (`id`), + KEY `idx_app_code` (`app_code`), + KEY `idx_user_code` (`user_code`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT="gpt collections"; + +-- dbgpt.gpts_app_detail definition +CREATE TABLE `gpts_app_detail` ( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', + `app_code` varchar(255) NOT NULL COMMENT 'Current AI assistant code', + `app_name` varchar(255) NOT NULL COMMENT 'Current AI assistant name', + `agent_name` varchar(255) NOT NULL COMMENT ' Agent name', + `node_id` varchar(255) NOT NULL COMMENT 'Current AI assistant Agent Node id', + `resources` text COMMENT 'Agent bind resource', + `prompt_template` text COMMENT 'Agent bind template', + `llm_strategy` varchar(25) 
DEFAULT NULL COMMENT 'Agent use llm strategy', + `llm_strategy_value` text COMMENT 'Agent use llm strategy value', + `created_at` datetime DEFAULT NULL COMMENT 'create time', + `updated_at` datetime DEFAULT NULL COMMENT 'last update time', + PRIMARY KEY (`id`), + UNIQUE KEY `uk_gpts_app_agent_node` (`app_name`,`agent_name`,`node_id`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; + +CREATE +DATABASE IF NOT EXISTS EXAMPLE_1; +use EXAMPLE_1; +CREATE TABLE IF NOT EXISTS `users` +( + `id` int NOT NULL AUTO_INCREMENT, + `username` varchar(50) NOT NULL COMMENT '用户名', + `password` varchar(50) NOT NULL COMMENT '密码', + `email` varchar(50) NOT NULL COMMENT '邮箱', + `phone` varchar(20) DEFAULT NULL COMMENT '电话', + PRIMARY KEY (`id`), + KEY `idx_username` (`username`) COMMENT '索引:按用户名查询' +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT='聊天用户表'; + +INSERT INTO users (username, password, email, phone) +VALUES ('user_1', 'password_1', 'user_1@example.com', '12345678901'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_2', 'password_2', 'user_2@example.com', '12345678902'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_3', 'password_3', 'user_3@example.com', '12345678903'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_4', 'password_4', 'user_4@example.com', '12345678904'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_5', 'password_5', 'user_5@example.com', '12345678905'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_6', 'password_6', 'user_6@example.com', '12345678906'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_7', 'password_7', 'user_7@example.com', '12345678907'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_8', 'password_8', 'user_8@example.com', '12345678908'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_9', 'password_9', 'user_9@example.com', '12345678909'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_10', 'password_10', 'user_10@example.com', '12345678900'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_11', 'password_11', 'user_11@example.com', '12345678901'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_12', 'password_12', 'user_12@example.com', '12345678902'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_13', 'password_13', 'user_13@example.com', '12345678903'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_14', 'password_14', 'user_14@example.com', '12345678904'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_15', 'password_15', 'user_15@example.com', '12345678905'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_16', 'password_16', 'user_16@example.com', '12345678906'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_17', 'password_17', 'user_17@example.com', '12345678907'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_18', 'password_18', 'user_18@example.com', '12345678908'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_19', 'password_19', 'user_19@example.com', '12345678909'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_20', 'password_20', 'user_20@example.com', '12345678900'); \ No newline at end of file diff --git a/dbgpt/app/initialization/db_model_initialization.py 
b/dbgpt/app/initialization/db_model_initialization.py
index 4d6353bcc..0749ccdf0 100644
--- a/dbgpt/app/initialization/db_model_initialization.py
+++ b/dbgpt/app/initialization/db_model_initialization.py
@@ -1,11 +1,14 @@
 """Import all models to make sure they are registered with SQLAlchemy.
 """
+
 from dbgpt.app.knowledge.chunk_db import DocumentChunkEntity
 from dbgpt.app.knowledge.document_db import KnowledgeDocumentEntity
 from dbgpt.app.openapi.api_v1.feedback.feed_back_db import ChatFeedBackEntity
 from dbgpt.datasource.manages.connect_config_db import ConnectConfigEntity
+from dbgpt.model.cluster.registry_impl.db_storage import ModelInstanceEntity
 from dbgpt.serve.agent.db.my_plugin_db import MyPluginEntity
 from dbgpt.serve.agent.db.plugin_hub_db import PluginHubEntity
+from dbgpt.serve.flow.models.models import ServeEntity as FlowServeEntity
 from dbgpt.serve.prompt.models.models import ServeEntity as PromptManageEntity
 from dbgpt.serve.rag.models.models import KnowledgeSpaceEntity
 from dbgpt.storage.chat_history.chat_history_db import (
@@ -24,4 +27,6 @@ _MODELS = [
     ConnectConfigEntity,
     ChatHistoryEntity,
     ChatHistoryMessageEntity,
+    ModelInstanceEntity,
+    FlowServeEntity,
 ]
diff --git a/dbgpt/model/cli.py b/dbgpt/model/cli.py
index 24a367fb1..90731ac37 100644
--- a/dbgpt/model/cli.py
+++ b/dbgpt/model/cli.py
@@ -26,6 +26,7 @@ from dbgpt.util.parameter_utils import (
     build_lazy_click_command,
 )
 
+# You can set the environment variable CONTROLLER_ADDRESS to override the default address
 MODEL_CONTROLLER_ADDRESS = "http://127.0.0.1:8000"
 
 logger = logging.getLogger("dbgpt_cli")
diff --git a/dbgpt/model/cluster/apiserver/tests/test_api.py b/dbgpt/model/cluster/apiserver/tests/test_api.py
index d9001ed44..64d880962 100644
--- a/dbgpt/model/cluster/apiserver/tests/test_api.py
+++ b/dbgpt/model/cluster/apiserver/tests/test_api.py
@@ -1,22 +1,11 @@
-import importlib.metadata as metadata
-
 import pytest
 import pytest_asyncio
-from aioresponses import aioresponses
-from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
-from httpx import AsyncClient, HTTPError
+from httpx import ASGITransport, AsyncClient, HTTPError
 
 from dbgpt.component import SystemApp
 from dbgpt.model.cluster.apiserver.api import (
-    ChatCompletionResponse,
-    ChatCompletionResponseChoice,
-    ChatCompletionResponseStreamChoice,
-    ChatCompletionStreamResponse,
-    ChatMessage,
-    DeltaMessage,
     ModelList,
-    UsageInfo,
     api_settings,
     initialize_apiserver,
 )
@@ -56,12 +45,13 @@ async def client(request, system_app: SystemApp):
     if api_settings:
         # Clear global api keys
         api_settings.api_keys = []
-    async with AsyncClient(app=app, base_url="http://test", headers=headers) as client:
+    async with AsyncClient(
+        transport=ASGITransport(app), base_url="http://test", headers=headers
+    ) as client:
         async with _new_cluster(**param) as cluster:
             worker_manager, model_registry = cluster
             system_app.register(_DefaultWorkerManagerFactory, worker_manager)
             system_app.register_instance(model_registry)
-            # print(f"Instances {model_registry.registry}")
             initialize_apiserver(None, app, system_app, api_keys=api_keys)
             yield client
@@ -113,7 +103,11 @@ async def test_chat_completions(client: AsyncClient, expected_messages):
             "Hello world.",
             "abc",
         ),
-        ({"stream_messags": ["你好,我是", "张三。"], "api_keys": ["abc"]}, "你好,我是张三。", "abc"),
+        (
+            {"stream_messags": ["你好,我是", "张三。"], "api_keys": ["abc"]},
+            "你好,我是张三。",
+            "abc",
+        ),
     ],
    indirect=["client"],
 )
@@ -160,7 +154,11 @@ async def test_chat_completions_with_openai_lib_async_no_stream(
             "Hello world.",
             "abc",
         ),
-        ({"stream_messags": ["你好,我是", "张三。"], "api_keys": ["abc"]}, "你好,我是张三。", "abc"),
+        (
+            {"stream_messags": ["你好,我是", "张三。"], "api_keys": ["abc"]},
+            "你好,我是张三。",
+            "abc",
+        ),
     ],
     indirect=["client"],
 )
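Reviewer note on the test changes above: recent httpx releases dropped the `app=...` shortcut on `AsyncClient`, hence the explicit `ASGITransport`. A minimal sketch of the migrated pattern (the `app` here is a placeholder, not the apiserver fixture):

```python
import httpx
from fastapi import FastAPI

app = FastAPI()  # placeholder ASGI app

async def make_client() -> httpx.AsyncClient:
    # Route requests directly into the ASGI app instead of opening sockets.
    transport = httpx.ASGITransport(app)
    return httpx.AsyncClient(transport=transport, base_url="http://test")
```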
"abc", ), - ({"stream_messags": ["你好,我是", "张三。"], "api_keys": ["abc"]}, "你好,我是张三。", "abc"), + ( + {"stream_messags": ["你好,我是", "张三。"], "api_keys": ["abc"]}, + "你好,我是张三。", + "abc", + ), ], indirect=["client"], ) diff --git a/dbgpt/model/cluster/controller/controller.py b/dbgpt/model/cluster/controller/controller.py index 64311cd96..a28e4ccb7 100644 --- a/dbgpt/model/cluster/controller/controller.py +++ b/dbgpt/model/cluster/controller/controller.py @@ -1,6 +1,6 @@ import logging from abc import ABC, abstractmethod -from typing import List +from typing import List, Literal, Optional from fastapi import APIRouter @@ -8,6 +8,7 @@ from dbgpt.component import BaseComponent, ComponentType, SystemApp from dbgpt.model.base import ModelInstance from dbgpt.model.cluster.registry import EmbeddedModelRegistry, ModelRegistry from dbgpt.model.parameter import ModelControllerParameters +from dbgpt.util.api_utils import APIMixin from dbgpt.util.api_utils import _api_remote as api_remote from dbgpt.util.api_utils import _sync_api_remote as sync_api_remote from dbgpt.util.fastapi import create_app @@ -46,9 +47,7 @@ class BaseModelController(BaseComponent, ABC): class LocalModelController(BaseModelController): - def __init__(self, registry: ModelRegistry = None) -> None: - if not registry: - registry = EmbeddedModelRegistry() + def __init__(self, registry: ModelRegistry) -> None: self.registry = registry self.deployment = None @@ -75,9 +74,25 @@ class LocalModelController(BaseModelController): return await self.registry.send_heartbeat(instance) -class _RemoteModelController(BaseModelController): - def __init__(self, base_url: str) -> None: - self.base_url = base_url +class _RemoteModelController(APIMixin, BaseModelController): + def __init__( + self, + urls: str, + health_check_interval_secs: int = 5, + health_check_timeout_secs: int = 30, + check_health: bool = True, + choice_type: Literal["latest_first", "random"] = "latest_first", + ) -> None: + APIMixin.__init__( + self, + urls=urls, + health_check_path="/api/health", + health_check_interval_secs=health_check_interval_secs, + health_check_timeout_secs=health_check_timeout_secs, + check_health=check_health, + choice_type=choice_type, + ) + BaseModelController.__init__(self) @api_remote(path="/api/controller/models", method="POST") async def register_instance(self, instance: ModelInstance) -> bool: @@ -139,13 +154,19 @@ controller = ModelControllerAdapter() def initialize_controller( - app=None, remote_controller_addr: str = None, host: str = None, port: int = None + app=None, + remote_controller_addr: str = None, + host: str = None, + port: int = None, + registry: Optional[ModelRegistry] = None, ): global controller if remote_controller_addr: controller.backend = _RemoteModelController(remote_controller_addr) else: - controller.backend = LocalModelController() + if not registry: + registry = EmbeddedModelRegistry() + controller.backend = LocalModelController(registry=registry) if app: app.include_router(router, prefix="/api", tags=["Model"]) @@ -158,6 +179,12 @@ def initialize_controller( uvicorn.run(app, host=host, port=port, log_level="info") +@router.get("/health") +async def api_health_check(): + """Health check API.""" + return {"status": "ok"} + + @router.post("/controller/models") async def api_register_instance(request: ModelInstance): return await controller.register_instance(request) @@ -179,6 +206,87 @@ async def api_model_heartbeat(request: ModelInstance): return await controller.send_heartbeat(request) +def _create_registry(controller_params: 
+def _create_registry(controller_params: ModelControllerParameters) -> ModelRegistry:
+    """Create a model registry based on the controller parameters.
+
+    The registry stores the metadata of all model instances. Backed by a database,
+    it can serve as a highly available service for model instances, and more
+    registry types can be implemented in the future.
+    """
+    registry_type = controller_params.registry_type.strip()
+    if registry_type == "embedded":
+        return EmbeddedModelRegistry(
+            heartbeat_interval_secs=controller_params.heartbeat_interval_secs,
+            heartbeat_timeout_secs=controller_params.heartbeat_timeout_secs,
+        )
+    elif registry_type == "database":
+        from urllib.parse import quote
+        from urllib.parse import quote_plus as urlquote
+
+        from dbgpt.model.cluster.registry_impl.storage import StorageModelRegistry
+
+        try_to_create_db = False
+
+        if controller_params.registry_db_type == "mysql":
+            db_name = controller_params.registry_db_name
+            db_host = controller_params.registry_db_host
+            db_port = controller_params.registry_db_port
+            db_user = controller_params.registry_db_user
+            db_password = controller_params.registry_db_password
+            if not db_name:
+                raise ValueError(
+                    "Registry DB name is required when using MySQL registry."
+                )
+            if not db_host:
+                raise ValueError(
+                    "Registry DB host is required when using MySQL registry."
+                )
+            if not db_port:
+                raise ValueError(
+                    "Registry DB port is required when using MySQL registry."
+                )
+            if not db_user:
+                raise ValueError(
+                    "Registry DB user is required when using MySQL registry."
+                )
+            if not db_password:
+                raise ValueError(
+                    "Registry DB password is required when using MySQL registry."
+                )
+            db_url = (
+                f"mysql+pymysql://{quote(db_user)}:"
+                f"{urlquote(db_password)}@"
+                f"{db_host}:"
+                f"{str(db_port)}/"
+                f"{db_name}?charset=utf8mb4"
+            )
+        elif controller_params.registry_db_type == "sqlite":
+            db_name = controller_params.registry_db_name
+            if not db_name:
+                raise ValueError(
+                    "Registry DB name is required when using SQLite registry."
+                )
+            db_url = f"sqlite:///{db_name}"
+            try_to_create_db = True
+        else:
+            raise ValueError(
+                f"Unsupported registry DB type: {controller_params.registry_db_type}"
+            )
+
+        registry = StorageModelRegistry.from_url(
+            db_url,
+            db_name,
+            pool_size=controller_params.registry_db_pool_size,
+            max_overflow=controller_params.registry_db_max_overflow,
+            try_to_create_db=try_to_create_db,
+            heartbeat_interval_secs=controller_params.heartbeat_interval_secs,
+            heartbeat_timeout_secs=controller_params.heartbeat_timeout_secs,
+        )
+        return registry
+    else:
+        raise ValueError(f"Unsupported registry type: {registry_type}")
+
+
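Reviewer note: since `run_model_controller` below parses `ModelControllerParameters` with the `controller_` environment prefix, the database registry can be selected without code changes. A hedged sketch; the variable names are inferred from the dataclass fields plus the `CONTROLLER_REGISTRY_DB_PASSWORD` example in the parameter help text:

```python
import os

# Hypothetical configuration for a MySQL-backed registry.
os.environ["CONTROLLER_REGISTRY_TYPE"] = "database"
os.environ["CONTROLLER_REGISTRY_DB_TYPE"] = "mysql"
os.environ["CONTROLLER_REGISTRY_DB_HOST"] = "127.0.0.1"
os.environ["CONTROLLER_REGISTRY_DB_PORT"] = "3306"
os.environ["CONTROLLER_REGISTRY_DB_USER"] = "root"
os.environ["CONTROLLER_REGISTRY_DB_PASSWORD"] = "your_password"
os.environ["CONTROLLER_REGISTRY_DB_NAME"] = "dbgpt"

from dbgpt.model.cluster.controller.controller import run_model_controller

run_model_controller()  # blocks; starts the controller HTTP service
```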
 def run_model_controller():
     parser = EnvArgumentParser()
     env_prefix = "controller_"
@@ -192,8 +300,11 @@ def run_model_controller():
         logging_level=controller_params.log_level,
         logger_filename=controller_params.log_file,
     )
+    registry = _create_registry(controller_params)
 
-    initialize_controller(host=controller_params.host, port=controller_params.port)
+    initialize_controller(
+        host=controller_params.host, port=controller_params.port, registry=registry
+    )
 
 
 if __name__ == "__main__":
diff --git a/tools/__init__.py b/dbgpt/model/cluster/registry_impl/__init__.py
similarity index 100%
rename from tools/__init__.py
rename to dbgpt/model/cluster/registry_impl/__init__.py
diff --git a/dbgpt/model/cluster/registry_impl/db_storage.py b/dbgpt/model/cluster/registry_impl/db_storage.py
new file mode 100644
index 000000000..c10f05738
--- /dev/null
+++ b/dbgpt/model/cluster/registry_impl/db_storage.py
@@ -0,0 +1,116 @@
+from datetime import datetime
+
+from sqlalchemy import (
+    Boolean,
+    Column,
+    DateTime,
+    Float,
+    Integer,
+    String,
+    UniqueConstraint,
+)
+from sqlalchemy.orm import Session
+
+from dbgpt.core.interface.storage import ResourceIdentifier, StorageItemAdapter
+from dbgpt.storage.metadata import Model
+
+from .storage import ModelInstanceStorageItem
+
+
+class ModelInstanceEntity(Model):
+    """Model instance entity.
+
+    The database serves as the registry; this is the table schema of the model instance.
+ """ + + __tablename__ = "dbgpt_cluster_registry_instance" + __table_args__ = ( + UniqueConstraint( + "model_name", + "host", + "port", + "sys_code", + name="uk_model_instance", + ), + ) + id = Column(Integer, primary_key=True, comment="Auto increment id") + model_name = Column(String(128), nullable=False, comment="Model name") + host = Column(String(128), nullable=False, comment="Host of the model") + port = Column(Integer, nullable=False, comment="Port of the model") + weight = Column(Float, nullable=True, default=1.0, comment="Weight of the model") + check_healthy = Column( + Boolean, + nullable=True, + default=True, + comment="Whether to check the health of the model", + ) + healthy = Column( + Boolean, nullable=True, default=False, comment="Whether the model is healthy" + ) + enabled = Column( + Boolean, nullable=True, default=True, comment="Whether the model is enabled" + ) + prompt_template = Column( + String(128), + nullable=True, + comment="Prompt template for the model instance", + ) + last_heartbeat = Column( + DateTime, + nullable=True, + comment="Last heartbeat time of the model instance", + ) + user_name = Column(String(128), nullable=True, comment="User name") + sys_code = Column(String(128), nullable=True, comment="System code") + gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time") + gmt_modified = Column(DateTime, default=datetime.now, comment="Record update time") + + +class ModelInstanceItemAdapter( + StorageItemAdapter[ModelInstanceStorageItem, ModelInstanceEntity] +): + def to_storage_format(self, item: ModelInstanceStorageItem) -> ModelInstanceEntity: + return ModelInstanceEntity( + model_name=item.model_name, + host=item.host, + port=item.port, + weight=item.weight, + check_healthy=item.check_healthy, + healthy=item.healthy, + enabled=item.enabled, + prompt_template=item.prompt_template, + last_heartbeat=item.last_heartbeat, + # user_name=item.user_name, + # sys_code=item.sys_code, + ) + + def from_storage_format( + self, model: ModelInstanceEntity + ) -> ModelInstanceStorageItem: + return ModelInstanceStorageItem( + model_name=model.model_name, + host=model.host, + port=model.port, + weight=model.weight, + check_healthy=model.check_healthy, + healthy=model.healthy, + enabled=model.enabled, + prompt_template=model.prompt_template, + last_heartbeat=model.last_heartbeat, + ) + + def get_query_for_identifier( + self, + storage_format: ModelInstanceEntity, + resource_id: ResourceIdentifier, + **kwargs, + ): + session: Session = kwargs.get("session") + if session is None: + raise Exception("session is None") + query_obj = session.query(ModelInstanceEntity) + for key, value in resource_id.to_dict().items(): + if value is None: + continue + query_obj = query_obj.filter(getattr(ModelInstanceEntity, key) == value) + return query_obj diff --git a/dbgpt/model/cluster/registry_impl/storage.py b/dbgpt/model/cluster/registry_impl/storage.py new file mode 100644 index 000000000..6cafed7cc --- /dev/null +++ b/dbgpt/model/cluster/registry_impl/storage.py @@ -0,0 +1,374 @@ +import threading +import time +from concurrent.futures import Executor, ThreadPoolExecutor +from dataclasses import dataclass, field +from datetime import datetime, timedelta +from typing import Dict, List, Optional, Tuple + +from dbgpt.component import SystemApp +from dbgpt.core.interface.storage import ( + QuerySpec, + ResourceIdentifier, + StorageInterface, + StorageItem, +) +from dbgpt.util.executor_utils import blocking_func_to_async + +from ...base import ModelInstance +from 
..registry import ModelRegistry
+
+
+@dataclass
+class ModelInstanceIdentifier(ResourceIdentifier):
+    identifier_split: str = field(default="___$$$$___", init=False)
+    model_name: str
+    host: str
+    port: int
+
+    def __post_init__(self):
+        """Post init method."""
+        if self.model_name is None:
+            raise ValueError("model_name is required.")
+        if self.host is None:
+            raise ValueError("host is required.")
+        if self.port is None:
+            raise ValueError("port is required.")
+
+        if any(
+            self.identifier_split in key
+            for key in [self.model_name, self.host, str(self.port)]
+            if key is not None
+        ):
+            raise ValueError(
+                f"identifier_split {self.identifier_split} is not allowed in "
+                f"model_name, host, port."
+            )
+
+    @property
+    def str_identifier(self) -> str:
+        """Return the string identifier of the identifier."""
+        return self.identifier_split.join(
+            key
+            for key in [
+                self.model_name,
+                self.host,
+                str(self.port),
+            ]
+            if key is not None
+        )
+
+    def to_dict(self) -> Dict:
+        """Convert the identifier to a dict.
+
+        Returns:
+            Dict: The dict of the identifier.
+        """
+        return {
+            "model_name": self.model_name,
+            "host": self.host,
+            "port": self.port,
+        }
+
+
+@dataclass
+class ModelInstanceStorageItem(StorageItem):
+    """Model instance item, the form in which an instance is kept in storage."""
+
+    model_name: str
+    host: str
+    port: int
+    weight: Optional[float] = 1.0
+    check_healthy: Optional[bool] = True
+    healthy: Optional[bool] = False
+    enabled: Optional[bool] = True
+    prompt_template: Optional[str] = None
+    last_heartbeat: Optional[datetime] = None
+    _identifier: ModelInstanceIdentifier = field(init=False)
+
+    def __post_init__(self):
+        """Post init method."""
+        # Convert last_heartbeat to datetime if it's a timestamp
+        if isinstance(self.last_heartbeat, (int, float)):
+            self.last_heartbeat = datetime.fromtimestamp(self.last_heartbeat)
+
+        self._identifier = ModelInstanceIdentifier(
+            model_name=self.model_name,
+            host=self.host,
+            port=self.port,
+        )
+
+    @property
+    def identifier(self) -> ModelInstanceIdentifier:
+        return self._identifier
+
+    def merge(self, other: "StorageItem") -> None:
+        if not isinstance(other, ModelInstanceStorageItem):
+            raise ValueError(f"Cannot merge with {type(other)}")
+        self.from_object(other)
+
+    def to_dict(self) -> Dict:
+        # Guard against a missing heartbeat to avoid calling .timestamp() on None
+        last_heartbeat = (
+            self.last_heartbeat.timestamp() if self.last_heartbeat else None
+        )
+        return {
+            "model_name": self.model_name,
+            "host": self.host,
+            "port": self.port,
+            "weight": self.weight,
+            "check_healthy": self.check_healthy,
+            "healthy": self.healthy,
+            "enabled": self.enabled,
+            "prompt_template": self.prompt_template,
+            "last_heartbeat": last_heartbeat,
+        }
+
+    def from_object(self, item: "ModelInstanceStorageItem") -> None:
+        """Build the item from another item."""
+        self.model_name = item.model_name
+        self.host = item.host
+        self.port = item.port
+        self.weight = item.weight
+        self.check_healthy = item.check_healthy
+        self.healthy = item.healthy
+        self.enabled = item.enabled
+        self.prompt_template = item.prompt_template
+        self.last_heartbeat = item.last_heartbeat
+
+    @classmethod
+    def from_model_instance(cls, instance: ModelInstance) -> "ModelInstanceStorageItem":
+        return cls(
+            model_name=instance.model_name,
+            host=instance.host,
+            port=instance.port,
+            weight=instance.weight,
+            check_healthy=instance.check_healthy,
+            healthy=instance.healthy,
+            enabled=instance.enabled,
+            prompt_template=instance.prompt_template,
+            last_heartbeat=instance.last_heartbeat,
+        )
+
+    @classmethod
+    def to_model_instance(cls, item: "ModelInstanceStorageItem") -> ModelInstance:
+        return ModelInstance(
+            model_name=item.model_name,
+            host=item.host,
+            port=item.port,
+            weight=item.weight,
+            check_healthy=item.check_healthy,
+            healthy=item.healthy,
+            enabled=item.enabled,
+            prompt_template=item.prompt_template,
+            last_heartbeat=item.last_heartbeat,
+        )
+
+
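Reviewer note: to make the identifier scheme above concrete, a small hedged example (the join token comes from `identifier_split` in this diff; the model name is hypothetical):

```python
from dbgpt.model.cluster.registry_impl.storage import (
    ModelInstanceIdentifier,
    ModelInstanceStorageItem,
)

ident = ModelInstanceIdentifier(
    model_name="vicuna-13b-v1.5", host="127.0.0.1", port=8001
)
# model_name, host and port are joined with the identifier_split token
assert ident.str_identifier == "vicuna-13b-v1.5___$$$$___127.0.0.1___$$$$___8001"

item = ModelInstanceStorageItem(
    model_name="vicuna-13b-v1.5", host="127.0.0.1", port=8001
)
# The storage item derives the same identifier in __post_init__
assert item.identifier.to_dict() == {
    "model_name": "vicuna-13b-v1.5",
    "host": "127.0.0.1",
    "port": 8001,
}
```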
+class StorageModelRegistry(ModelRegistry):
+    """Model registry backed by a pluggable storage implementation."""
+
+    def __init__(
+        self,
+        storage: StorageInterface,
+        system_app: SystemApp | None = None,
+        executor: Optional[Executor] = None,
+        heartbeat_interval_secs: float | int = 60,
+        heartbeat_timeout_secs: int = 120,
+    ):
+        super().__init__(system_app)
+        self._storage = storage
+        self._executor = executor or ThreadPoolExecutor(max_workers=2)
+        self.heartbeat_interval_secs = heartbeat_interval_secs
+        self.heartbeat_timeout_secs = heartbeat_timeout_secs
+        self.heartbeat_thread = threading.Thread(target=self._heartbeat_checker)
+        self.heartbeat_thread.daemon = True
+        self.heartbeat_thread.start()
+
+    @classmethod
+    def from_url(
+        cls,
+        db_url: str,
+        db_name: str,
+        pool_size: int = 5,
+        max_overflow: int = 10,
+        try_to_create_db: bool = False,
+        **kwargs,
+    ) -> "StorageModelRegistry":
+        from dbgpt.storage.metadata.db_manager import DatabaseManager, initialize_db
+        from dbgpt.storage.metadata.db_storage import SQLAlchemyStorage
+        from dbgpt.util.serialization.json_serialization import JsonSerializer
+
+        from .db_storage import ModelInstanceEntity, ModelInstanceItemAdapter
+
+        engine_args = {
+            "pool_size": pool_size,
+            "max_overflow": max_overflow,
+            "pool_timeout": 30,
+            "pool_recycle": 3600,
+            "pool_pre_ping": True,
+        }
+
+        db: DatabaseManager = initialize_db(
+            db_url, db_name, engine_args, try_to_create_db=try_to_create_db
+        )
+        storage_adapter = ModelInstanceItemAdapter()
+        serializer = JsonSerializer()
+        storage = SQLAlchemyStorage(
+            db,
+            ModelInstanceEntity,
+            storage_adapter,
+            serializer,
+        )
+        return cls(storage, **kwargs)
+
+    async def _get_instances_by_model(
+        self, model_name: str, host: str, port: int, healthy_only: bool = False
+    ) -> Tuple[List[ModelInstanceStorageItem], List[ModelInstanceStorageItem]]:
+        query_spec = QuerySpec(conditions={"model_name": model_name})
+        # Query all instances of the model
+        instances = await blocking_func_to_async(
+            self._executor, self._storage.query, query_spec, ModelInstanceStorageItem
+        )
+        if healthy_only:
+            instances = [ins for ins in instances if ins.healthy is True]
+        exist_ins = [ins for ins in instances if ins.host == host and ins.port == port]
+        return instances, exist_ins
+
+    def _heartbeat_checker(self):
+        # Runs in a daemon thread; marks instances unhealthy once their last
+        # heartbeat is older than heartbeat_timeout_secs
+        while True:
+            all_instances: List[ModelInstanceStorageItem] = self._storage.query(
+                QuerySpec(conditions={}), ModelInstanceStorageItem
+            )
+            for instance in all_instances:
+                if (
+                    instance.check_healthy
+                    and datetime.now() - instance.last_heartbeat
+                    > timedelta(seconds=self.heartbeat_timeout_secs)
+                ):
+                    instance.healthy = False
+                    self._storage.update(instance)
+            time.sleep(self.heartbeat_interval_secs)
+
+    async def register_instance(self, instance: ModelInstance) -> bool:
+        model_name = instance.model_name.strip()
+        host = instance.host.strip()
+        port = instance.port
+        _, exist_ins = await self._get_instances_by_model(
+            model_name, host, port, healthy_only=False
+        )
+        if exist_ins:
+            # An existing instance was found; just update it
+            # (there is at most one match for the model/host/port key)
+            ins: ModelInstanceStorageItem = exist_ins[0]
+            # Update instance
+            ins.weight = instance.weight
+            ins.healthy = True
+            ins.prompt_template = instance.prompt_template
+            ins.last_heartbeat = datetime.now()
+            await blocking_func_to_async(self._executor, self._storage.update, ins)
+        else:
+            # No existing instance, save the new instance
+            new_inst = ModelInstanceStorageItem.from_model_instance(instance)
+            new_inst.healthy = True
+            new_inst.last_heartbeat = datetime.now()
+            await blocking_func_to_async(self._executor, self._storage.save, new_inst)
+        return True
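Reviewer aside between methods: together with `register_instance` above and the deregister/heartbeat methods that follow, the registry behaves as an upsert store keyed by (model_name, host, port). A hedged usage sketch, loosely mirroring the unit tests added later in this diff (`registry` is any `StorageModelRegistry` instance):

```python
from datetime import datetime

from dbgpt.model.base import ModelInstance

async def lifecycle_demo(registry) -> None:
    # Hypothetical instance; field values follow the test fixtures in this diff.
    instance = ModelInstance(
        model_name="vicuna-13b-v1.5",
        host="127.0.0.1",
        port=8001,
        last_heartbeat=datetime.now(),
    )
    await registry.register_instance(instance)  # insert, or update if it exists
    await registry.send_heartbeat(instance)  # refreshes last_heartbeat
    healthy = await registry.get_all_instances("vicuna-13b-v1.5", healthy_only=True)
    assert len(healthy) == 1
    await registry.deregister_instance(instance)  # row is kept but marked unhealthy
```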
+    async def deregister_instance(self, instance: ModelInstance) -> bool:
+        """Deregister a model instance.
+
+        If the instance exists, mark it as unhealthy; nothing is done if the
+        instance does not exist.
+
+        Args:
+            instance (ModelInstance): The instance to deregister.
+        """
+        model_name = instance.model_name.strip()
+        host = instance.host.strip()
+        port = instance.port
+        _, exist_ins = await self._get_instances_by_model(
+            model_name, host, port, healthy_only=False
+        )
+        if exist_ins:
+            ins = exist_ins[0]
+            ins.healthy = False
+            await blocking_func_to_async(self._executor, self._storage.update, ins)
+        return True
+
+    async def get_all_instances(
+        self, model_name: str, healthy_only: bool = False
+    ) -> List[ModelInstance]:
+        """Get all instances of a model (async).
+
+        Args:
+            model_name (str): The model name.
+            healthy_only (bool): Whether to return only healthy instances.
+                Defaults to False.
+        """
+        return await blocking_func_to_async(
+            self._executor, self.sync_get_all_instances, model_name, healthy_only
+        )
+
+    def sync_get_all_instances(
+        self, model_name: str, healthy_only: bool = False
+    ) -> List[ModelInstance]:
+        """Get all instances of a model.
+
+        Args:
+            model_name (str): The model name.
+            healthy_only (bool): Whether to return only healthy instances.
+                Defaults to False.
+
+        Returns:
+            List[ModelInstance]: The list of instances.
+        """
+        instances = self._storage.query(
+            QuerySpec(conditions={"model_name": model_name}), ModelInstanceStorageItem
+        )
+        if healthy_only:
+            instances = [ins for ins in instances if ins.healthy is True]
+        return [ModelInstanceStorageItem.to_model_instance(ins) for ins in instances]
+
+    async def get_all_model_instances(
+        self, healthy_only: bool = False
+    ) -> List[ModelInstance]:
+        """Get all model instances.
+
+        Args:
+            healthy_only (bool): Whether to return only healthy instances.
+                Defaults to False.
+
+        Returns:
+            List[ModelInstance]: The list of instances.
+        """
+        all_instances = await blocking_func_to_async(
+            self._executor,
+            self._storage.query,
+            QuerySpec(conditions={}),
+            ModelInstanceStorageItem,
+        )
+        if healthy_only:
+            all_instances = [ins for ins in all_instances if ins.healthy is True]
+        return [
+            ModelInstanceStorageItem.to_model_instance(ins) for ins in all_instances
+        ]
+
+    async def send_heartbeat(self, instance: ModelInstance) -> bool:
+        """Receive a heartbeat from a model instance.
+
+        Update the last heartbeat time of the instance. If the instance does not
+        exist, register the instance.
+
+        Args:
+            instance (ModelInstance): The instance sending the heartbeat.
+
+        Returns:
+            bool: True if the heartbeat is received successfully.
+ """ + model_name = instance.model_name.strip() + host = instance.host.strip() + port = instance.port + _, exist_ins = await self._get_instances_by_model( + model_name, host, port, healthy_only=False + ) + if not exist_ins: + # register new instance from heartbeat + await self.register_instance(instance) + return True + else: + ins = exist_ins[0] + ins.last_heartbeat = datetime.now() + ins.healthy = True + await blocking_func_to_async(self._executor, self._storage.update, ins) + return True diff --git a/dbgpt/model/cluster/tests/registry_impl/__init__.py b/dbgpt/model/cluster/tests/registry_impl/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/dbgpt/model/cluster/tests/registry_impl/test_storage_registry.py b/dbgpt/model/cluster/tests/registry_impl/test_storage_registry.py new file mode 100644 index 000000000..6158a8c25 --- /dev/null +++ b/dbgpt/model/cluster/tests/registry_impl/test_storage_registry.py @@ -0,0 +1,221 @@ +import asyncio +from concurrent.futures import ThreadPoolExecutor +from datetime import datetime, timedelta +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from dbgpt.core.interface.storage import InMemoryStorage, QuerySpec +from dbgpt.util.serialization.json_serialization import JsonSerializer + +from ...registry_impl.storage import ( + ModelInstance, + ModelInstanceStorageItem, + StorageModelRegistry, +) + + +@pytest.fixture +def in_memory_storage(): + return InMemoryStorage(serializer=JsonSerializer()) + + +@pytest.fixture +def thread_pool_executor(): + return ThreadPoolExecutor(max_workers=2) + + +@pytest.fixture +def registry(in_memory_storage, thread_pool_executor): + return StorageModelRegistry( + storage=in_memory_storage, + executor=thread_pool_executor, + heartbeat_interval_secs=1, + heartbeat_timeout_secs=2, + ) + + +@pytest.fixture +def model_instance(): + return ModelInstance( + model_name="test_model", + host="localhost", + port=8080, + weight=1.0, + check_healthy=True, + healthy=True, + enabled=True, + prompt_template=None, + last_heartbeat=datetime.now(), + ) + + +@pytest.fixture +def model_instance_2(): + return ModelInstance( + model_name="test_model", + host="localhost", + port=8081, + weight=1.0, + check_healthy=True, + healthy=True, + enabled=True, + prompt_template=None, + last_heartbeat=datetime.now(), + ) + + +@pytest.fixture +def model_instance_3(): + return ModelInstance( + model_name="test_model_2", + host="localhost", + port=8082, + weight=1.0, + check_healthy=True, + healthy=True, + enabled=True, + prompt_template=None, + last_heartbeat=datetime.now(), + ) + + +@pytest.fixture +def model_instance_storage_item(model_instance): + return ModelInstanceStorageItem.from_model_instance(model_instance) + + +@pytest.mark.asyncio +async def test_register_instance_new(registry, model_instance): + """Test registering a new model instance.""" + result = await registry.register_instance(model_instance) + + assert result is True + instances = await registry.get_all_instances(model_instance.model_name) + assert len(instances) == 1 + saved_instance = instances[0] + assert saved_instance.model_name == model_instance.model_name + assert saved_instance.host == model_instance.host + assert saved_instance.port == model_instance.port + assert saved_instance.healthy is True + assert saved_instance.last_heartbeat is not None + + +@pytest.mark.asyncio +async def test_register_instance_existing( + registry, model_instance, model_instance_storage_item +): + """Test registering an existing model instance and updating it.""" + 
+    await registry.register_instance(model_instance)
+
+    # Register the instance again with updated heartbeat
+    result = await registry.register_instance(model_instance)
+
+    assert result is True
+    instances = await registry.get_all_instances(model_instance.model_name)
+    assert len(instances) == 1
+    updated_instance = instances[0]
+    assert updated_instance.model_name == model_instance.model_name
+    assert updated_instance.host == model_instance.host
+    assert updated_instance.port == model_instance.port
+    assert updated_instance.healthy is True
+    assert updated_instance.last_heartbeat is not None
+
+
+@pytest.mark.asyncio
+async def test_deregister_instance(registry, model_instance):
+    """Test deregistering a model instance."""
+    await registry.register_instance(model_instance)
+
+    result = await registry.deregister_instance(model_instance)
+
+    assert result is True
+    instances = await registry.get_all_instances(model_instance.model_name)
+    assert len(instances) == 1
+    deregistered_instance = instances[0]
+    assert deregistered_instance.healthy is False
+
+
+@pytest.mark.asyncio
+async def test_get_all_instances(registry, model_instance):
+    """Test retrieving all model instances."""
+    await registry.register_instance(model_instance)
+
+    result = await registry.get_all_instances(
+        model_instance.model_name, healthy_only=True
+    )
+
+    assert len(result) == 1
+    assert result[0].model_name == model_instance.model_name
+
+
+def test_sync_get_all_instances(registry, model_instance):
+    """Test synchronously retrieving all model instances."""
+    registry.sync_get_all_instances(model_instance.model_name, healthy_only=True)
+    registry._storage.save(ModelInstanceStorageItem.from_model_instance(model_instance))
+
+    result = registry.sync_get_all_instances(
+        model_instance.model_name, healthy_only=True
+    )
+
+    assert len(result) == 1
+    assert result[0].model_name == model_instance.model_name
+
+
+@pytest.mark.asyncio
+async def test_send_heartbeat_new_instance(registry, model_instance):
+    """Test sending a heartbeat for a new instance."""
+    result = await registry.send_heartbeat(model_instance)
+
+    assert result is True
+    instances = await registry.get_all_instances(model_instance.model_name)
+    assert len(instances) == 1
+    saved_instance = instances[0]
+    assert saved_instance.model_name == model_instance.model_name
+
+
+@pytest.mark.asyncio
+async def test_send_heartbeat_existing_instance(registry, model_instance):
+    """Test sending a heartbeat for an existing instance."""
+    await registry.register_instance(model_instance)
+
+    # Send heartbeat to update the instance
+    result = await registry.send_heartbeat(model_instance)
+
+    assert result is True
+    instances = await registry.get_all_instances(model_instance.model_name)
+    assert len(instances) == 1
+    updated_instance = instances[0]
+    assert updated_instance.last_heartbeat > model_instance.last_heartbeat
+
+
+@pytest.mark.asyncio
+async def test_heartbeat_checker(
+    in_memory_storage, thread_pool_executor, model_instance
+):
+    """Test the heartbeat checker mechanism."""
+    heartbeat_timeout_secs = 1
+    registry = StorageModelRegistry(
+        storage=in_memory_storage,
+        executor=thread_pool_executor,
+        heartbeat_interval_secs=0.1,
+        heartbeat_timeout_secs=heartbeat_timeout_secs,
+    )
+
+    async def check_heartbeat(model_name: str, expected_healthy: bool):
+        instances = await registry.get_all_instances(model_name)
+        assert len(instances) == 1
+        updated_instance = instances[0]
+        assert updated_instance.healthy == expected_healthy
+
+    await registry.register_instance(model_instance)
+    # First heartbeat should be successful
+    await check_heartbeat(model_instance.model_name, True)
+    # Wait heartbeat timeout
+    await asyncio.sleep(heartbeat_timeout_secs + 0.5)
+    await check_heartbeat(model_instance.model_name, False)
+
+    # Send heartbeat again
+    await registry.send_heartbeat(model_instance)
+    # Should be healthy again
+    await check_heartbeat(model_instance.model_name, True)
diff --git a/dbgpt/model/cluster/worker/manager.py b/dbgpt/model/cluster/worker/manager.py
index 9bc2a80c7..449f52ac6 100644
--- a/dbgpt/model/cluster/worker/manager.py
+++ b/dbgpt/model/cluster/worker/manager.py
@@ -1059,6 +1059,10 @@ def initialize_worker_manager_in_client(
     if not app:
         raise Exception("app can't be None")
 
+    if system_app:
+        logger.info(f"Register WorkerManager {_DefaultWorkerManagerFactory.name}")
+        system_app.register(_DefaultWorkerManagerFactory, worker_manager)
+
     worker_params: ModelWorkerParameters = _parse_worker_params(
         model_name=model_name, model_path=model_path, controller_addr=controller_addr
     )
@@ -1104,8 +1108,6 @@ def initialize_worker_manager_in_client(
     if include_router and app:
         # mount WorkerManager router
         app.include_router(router, prefix="/api")
-    if system_app:
-        system_app.register(_DefaultWorkerManagerFactory, worker_manager)
 
 
 def run_worker_manager(
diff --git a/dbgpt/model/parameter.py b/dbgpt/model/parameter.py
index f3396d564..5aad75a50 100644
--- a/dbgpt/model/parameter.py
+++ b/dbgpt/model/parameter.py
@@ -55,6 +55,84 @@ class ModelControllerParameters(BaseParameters):
     port: Optional[int] = field(
         default=8000, metadata={"help": "Model Controller deploy port"}
     )
+    registry_type: Optional[str] = field(
+        default="embedded",
+        metadata={
+            "help": "Registry type: embedded, database...",
+            "valid_values": ["embedded", "database"],
+        },
+    )
+    registry_db_type: Optional[str] = field(
+        default="mysql",
+        metadata={
+            "help": "Registry database type, currently only sqlite and mysql are "
+            "supported; only valid when registry_type is database",
+            "valid_values": ["mysql", "sqlite"],
+        },
+    )
+    registry_db_name: Optional[str] = field(
+        default="dbgpt",
+        metadata={
+            "help": "Registry database name; only valid when registry_type is "
+            "database. For sqlite, set it to the full database file path"
+        },
+    )
+    registry_db_host: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Registry database host; only valid when registry_type is "
+            "database"
+        },
+    )
+    registry_db_port: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "Registry database port; only valid when registry_type is "
+            "database"
+        },
+    )
+    registry_db_user: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Registry database user; only valid when registry_type is "
+            "database"
+        },
+    )
+    registry_db_password: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Registry database password; only valid when registry_type is "
+            "database. We recommend using an environment variable to store the "
+            "password, e.g. export CONTROLLER_REGISTRY_DB_PASSWORD='your_password'"
+        },
+    )
+    registry_db_pool_size: Optional[int] = field(
+        default=5,
+        metadata={
+            "help": "Registry database connection pool size; only valid when "
+            "registry_type is database"
+        },
+    )
+    registry_db_max_overflow: Optional[int] = field(
+        default=10,
+        metadata={
+            "help": "Registry database connection pool max overflow; only valid "
+            "when registry_type is database"
+        },
+    )
+
+    heartbeat_interval_secs: Optional[int] = field(
+        default=20, metadata={"help": "The interval for checking heartbeats (seconds)"}
+    )
+    heartbeat_timeout_secs: Optional[int] = field(
+        default=60,
+        metadata={
+            "help": "The timeout for checking heartbeats (seconds); an instance is "
+            "marked unhealthy if it does not send a heartbeat within this time"
+        },
+    )
+
     daemon: Optional[bool] = field(
         default=False, metadata={"help": "Run Model Controller in background"}
     )
diff --git a/dbgpt/serve/core/tests/conftest.py b/dbgpt/serve/core/tests/conftest.py
index efc02cf47..b4a0d1c5e 100644
--- a/dbgpt/serve/core/tests/conftest.py
+++ b/dbgpt/serve/core/tests/conftest.py
@@ -3,7 +3,7 @@ from typing import Dict
 import pytest
 import pytest_asyncio
 from fastapi.middleware.cors import CORSMiddleware
-from httpx import AsyncClient
+from httpx import ASGITransport, AsyncClient
 
 from dbgpt.component import SystemApp
 from dbgpt.util import AppConfig
@@ -56,7 +56,9 @@ async def client(request, asystem_app: SystemApp):
 
     test_app = asystem_app.app
 
-    async with AsyncClient(app=test_app, base_url=base_url, headers=headers) as client:
+    async with AsyncClient(
+        transport=ASGITransport(test_app), base_url=base_url, headers=headers
+    ) as client:
         for router in routers:
             test_app.include_router(router)
         if app_caller:
diff --git a/dbgpt/serve/prompt/tests/test_endpoints.py b/dbgpt/serve/prompt/tests/test_endpoints.py
index f9fd1db74..e73a456c4 100644
--- a/dbgpt/serve/prompt/tests/test_endpoints.py
+++ b/dbgpt/serve/prompt/tests/test_endpoints.py
@@ -86,6 +86,7 @@ async def test_api_health(client: AsyncClient, asystem_app, has_auth: bool):
 )
 async def test_api_auth(client: AsyncClient):
     response = await client.get("/health")
+    response.raise_for_status()
     assert response.status_code == 200
     assert response.json() == {"status": "ok"}
 
diff --git a/dbgpt/util/api_utils.py b/dbgpt/util/api_utils.py
index 175cd55de..3312aaafc 100644
--- a/dbgpt/util/api_utils.py
+++ b/dbgpt/util/api_utils.py
@@ -1,13 +1,156 @@
+import asyncio
 import logging
+import threading
+import time
+from abc import ABC
+from concurrent.futures import Executor, ThreadPoolExecutor
 from dataclasses import asdict, is_dataclass
+from datetime import datetime, timedelta
 from inspect import signature
-from typing import List, Optional, Tuple, Type, TypeVar, Union, get_type_hints
+from typing import List, Literal, Optional, Tuple, Type, TypeVar, Union, get_type_hints
 
 T = TypeVar("T")
 
 logger = logging.getLogger(__name__)
 
 
+class APIMixin(ABC):
+    """API mixin class."""
+
+    def __init__(
+        self,
+        urls: Union[str, List[str]],
+        health_check_path: str,
+        health_check_interval_secs: int = 5,
+        health_check_timeout_secs: int = 30,
+        check_health: bool = True,
+        choice_type: Literal["latest_first", "random"] = "latest_first",
+        executor: Optional[Executor] = None,
+    ):
+        if isinstance(urls, str):
+            # Split by ","
+            urls = urls.split(",")
+        urls = [url.strip() for url in urls]
+        self._remote_urls = urls
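+        # Liveness is tracked per URL in _heartbeat_map; when check_health is
+        # True, the daemon thread started below periodically rebuilds
+        # _health_urls from it.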
+        self._health_check_path = health_check_path
+        self._health_urls = []
+        self._health_check_interval_secs = health_check_interval_secs
+        self._health_check_timeout_secs = health_check_timeout_secs
+        self._heartbeat_map = {}
+        self._choice_type = choice_type
+        self._heartbeat_thread = threading.Thread(target=self._heartbeat_checker)
+        self._heartbeat_executor = executor or ThreadPoolExecutor(max_workers=3)
+        self._heartbeat_stop_event = threading.Event()
+
+        if check_health:
+            self._heartbeat_thread.daemon = True
+            self._heartbeat_thread.start()
+
+    def _heartbeat_checker(self):
+        logger.debug("Running health check")
+        while not self._heartbeat_stop_event.is_set():
+            try:
+                healthy_urls = self._check_and_update_health()
+                logger.debug(f"Healthy urls: {healthy_urls}")
+            except Exception as e:
+                logger.warning(f"Health check failed, error: {e}")
+            time.sleep(self._health_check_interval_secs)
+
+    def __del__(self):
+        self._heartbeat_stop_event.set()
+
+    def _check_health(self, url: str) -> Tuple[bool, str]:
+        try:
+            import requests
+
+            logger.debug(f"Checking health for {url}")
+            req_url = url + self._health_check_path
+            response = requests.get(req_url, timeout=10)
+            return response.status_code == 200, url
+        except Exception as e:
+            logger.warning(f"Health check failed for {url}, error: {e}")
+            return False, url
+
+    def _check_and_update_health(self) -> List[str]:
+        """Check health of all remote urls and update the health urls list."""
+        check_tasks = []
+        check_results = []
+        for url in self._remote_urls:
+            check_tasks.append(self._heartbeat_executor.submit(self._check_health, url))
+        for task in check_tasks:
+            check_results.append(task.result())
+        now = datetime.now()
+        for is_healthy, url in check_results:
+            if is_healthy:
+                self._heartbeat_map[url] = now
+        healthy_urls = []
+        for url, last_heartbeat in self._heartbeat_map.items():
+            if now - last_heartbeat < timedelta(
+                seconds=self._health_check_interval_secs
+            ):
+                healthy_urls.append((url, last_heartbeat))
+        # Sort by last heartbeat time, latest first
+        healthy_urls.sort(key=lambda x: x[1], reverse=True)
+
+        self._health_urls = [url for url, _ in healthy_urls]
+        return self._health_urls
+
+    async def select_url(self, max_wait_health_timeout_secs: int = 2) -> str:
+        """Select a healthy url to send the request to.
+
+        If no healthy urls are found, select randomly.
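+
+        If max_wait_health_timeout_secs is greater than 0 and no healthy url
+        is known yet, wait up to that many seconds for the background health
+        check to report one before falling back to a random choice among all
+        configured urls.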
+ """ + import random + + def _select(urls: List[str]): + if self._choice_type == "latest_first": + return urls[0] + elif self._choice_type == "random": + return random.choice(urls) + else: + raise ValueError(f"Invalid choice type: {self._choice_type}") + + if self._health_urls: + return _select(self._health_urls) + elif max_wait_health_timeout_secs > 0: + start_time = datetime.now() + while datetime.now() - start_time < timedelta( + seconds=max_wait_health_timeout_secs + ): + if self._health_urls: + return _select(self._health_urls) + await asyncio.sleep(0.1) + logger.warning("No healthy urls found, selecting randomly") + return _select(self._remote_urls) + + def sync_select_url(self, max_wait_health_timeout_secs: int = 2) -> str: + """Synchronous version of select_url.""" + import random + import time + + def _select(urls: List[str]): + if self._choice_type == "latest_first": + return urls[0] + elif self._choice_type == "random": + return random.choice(urls) + else: + raise ValueError(f"Invalid choice type: {self._choice_type}") + + if self._health_urls: + return _select(self._health_urls) + elif max_wait_health_timeout_secs > 0: + start_time = datetime.now() + while datetime.now() - start_time < timedelta( + seconds=max_wait_health_timeout_secs + ): + if self._health_urls: + return _select(self._health_urls) + time.sleep(0.1) + logger.warning("No healthy urls found, selecting randomly") + return _select(self._remote_urls) + + def _extract_dataclass_from_generic(type_hint: Type[T]) -> Union[Type[T], None]: import typing_inspect @@ -17,7 +160,7 @@ def _extract_dataclass_from_generic(type_hint: Type[T]) -> Union[Type[T], None]: return None -def _build_request(self, func, path, method, *args, **kwargs): +def _build_request(self, base_url, func, path, method, *args, **kwargs): return_type = get_type_hints(func).get("return") if return_type is None: raise TypeError("Return type must be annotated in the decorated function.") @@ -27,7 +170,6 @@ def _build_request(self, func, path, method, *args, **kwargs): if not actual_dataclass: actual_dataclass = return_type sig = signature(func) - base_url = self.base_url # Get base_url from class instance bound = sig.bind(self, *args, **kwargs) bound.apply_defaults() @@ -61,13 +203,22 @@ def _build_request(self, func, path, method, *args, **kwargs): return return_type, actual_dataclass, request_params -def _api_remote(path, method="GET"): +def _api_remote(path: str, method: str = "GET", max_wait_health_timeout_secs: int = 2): def decorator(func): async def wrapper(self, *args, **kwargs): import httpx + if not isinstance(self, APIMixin): + raise TypeError( + "The class must inherit from APIMixin to use the @_api_remote " + "decorator." + ) + # Found a healthy url to send request + base_url = await self.select_url( + max_wait_health_timeout_secs=max_wait_health_timeout_secs + ) return_type, actual_dataclass, request_params = _build_request( - self, func, path, method, *args, **kwargs + self, base_url, func, path, method, *args, **kwargs ) async with httpx.AsyncClient() as client: response = await client.request(**request_params) @@ -84,13 +235,24 @@ def _api_remote(path, method="GET"): return decorator -def _sync_api_remote(path, method="GET"): +def _sync_api_remote( + path: str, method: str = "GET", max_wait_health_timeout_secs: int = 2 +): def decorator(func): def wrapper(self, *args, **kwargs): import requests + if not isinstance(self, APIMixin): + raise TypeError( + "The class must inherit from APIMixin to use the @_sync_api_remote " + "decorator." 
+ ) + base_url = self.sync_select_url( + max_wait_health_timeout_secs=max_wait_health_timeout_secs + ) + return_type, actual_dataclass, request_params = _build_request( - self, func, path, method, *args, **kwargs + self, base_url, func, path, method, *args, **kwargs ) response = requests.request(**request_params) diff --git a/dbgpt/util/tests/test_api_utils.py b/dbgpt/util/tests/test_api_utils.py new file mode 100644 index 000000000..2d0695dc3 --- /dev/null +++ b/dbgpt/util/tests/test_api_utils.py @@ -0,0 +1,105 @@ +import time +from concurrent.futures import ThreadPoolExecutor +from datetime import datetime, timedelta +from unittest.mock import MagicMock, patch + +import pytest + +from ..api_utils import APIMixin + + +# Mock requests.get +@pytest.fixture +def mock_requests_get(): + with patch("requests.get") as mock_get: + yield mock_get + + +@pytest.fixture +def apimixin(): + urls = "http://example.com,http://example2.com" + health_check_path = "/health" + apimixin = APIMixin(urls, health_check_path) + yield apimixin + # Ensure the executor is properly shut down after tests + apimixin._heartbeat_executor.shutdown(wait=False) + + +def test_apimixin_initialization(apimixin): + """Test APIMixin initialization with various parameters.""" + assert apimixin._remote_urls == ["http://example.com", "http://example2.com"] + assert apimixin._health_check_path == "/health" + assert apimixin._health_check_interval_secs == 5 + assert apimixin._health_check_timeout_secs == 30 + assert apimixin._choice_type == "latest_first" + assert isinstance(apimixin._heartbeat_executor, ThreadPoolExecutor) + + +def test_health_check(apimixin, mock_requests_get): + """Test the _check_health method.""" + url = "http://example.com" + + # Mocking a successful response + mock_response = MagicMock() + mock_response.status_code = 200 + mock_requests_get.return_value = mock_response + + is_healthy, checked_url = apimixin._check_health(url) + assert is_healthy + assert checked_url == url + + # Mocking a failed response + mock_requests_get.side_effect = Exception("Connection error") + is_healthy, checked_url = apimixin._check_health(url) + assert not is_healthy + assert checked_url == url + + +def test_check_and_update_health(apimixin, mock_requests_get): + """Test the _check_and_update_health method.""" + apimixin._heartbeat_map = { + "http://example.com": datetime.now() - timedelta(seconds=3), + "http://example2.com": datetime.now() - timedelta(seconds=10), + } + + # Mocking responses + def side_effect(url, timeout): + mock_response = MagicMock() + if url == "http://example.com/health": + mock_response.status_code = 200 + elif url == "http://example2.com/health": + mock_response.status_code = 500 + return mock_response + + mock_requests_get.side_effect = side_effect + + health_urls = apimixin._check_and_update_health() + assert "http://example.com" in health_urls + assert "http://example2.com" not in health_urls + + +@pytest.mark.asyncio +async def test_select_url(apimixin, mock_requests_get): + """Test the async select_url method.""" + apimixin._health_urls = ["http://example.com"] + + selected_url = await apimixin.select_url() + assert selected_url == "http://example.com" + + # Test with no healthy URLs + apimixin._health_urls = [] + selected_url = await apimixin.select_url(max_wait_health_timeout_secs=1) + assert selected_url in ["http://example.com", "http://example2.com"] + + +def test_sync_select_url(apimixin, mock_requests_get): + """Test the synchronous sync_select_url method.""" + apimixin._health_urls = 
["http://example.com"] + + selected_url = apimixin.sync_select_url() + assert selected_url == "http://example.com" + + # Test with no healthy URLs + apimixin._health_urls = [] + selected_url = apimixin.sync_select_url(max_wait_health_timeout_secs=1) + assert selected_url in ["http://example.com", "http://example2.com"] diff --git a/dbgpt/util/utils.py b/dbgpt/util/utils.py index 03ae1473d..0dd87543c 100644 --- a/dbgpt/util/utils.py +++ b/dbgpt/util/utils.py @@ -172,7 +172,7 @@ def setup_http_service_logging(exclude_paths: Optional[List[str]] = None): """ if not exclude_paths: # Not show heartbeat log - exclude_paths = ["/api/controller/heartbeat"] + exclude_paths = ["/api/controller/heartbeat", "/api/health"] uvicorn_logger = logging.getLogger("uvicorn.access") if uvicorn_logger: for path in exclude_paths: diff --git a/docker/base/Dockerfile b/docker/base/Dockerfile index 268913683..c30fa7981 100644 --- a/docker/base/Dockerfile +++ b/docker/base/Dockerfile @@ -34,15 +34,16 @@ RUN pip3 install --upgrade pip -i $PIP_INDEX_URL \ # install openai for proxyllm && pip3 install -i $PIP_INDEX_URL ".[openai]" -RUN (if [ "${LANGUAGE}" = "zh" ]; \ - # language is zh, download zh_core_web_sm from github - then wget https://github.com/explosion/spacy-models/releases/download/zh_core_web_sm-3.7.0/zh_core_web_sm-3.7.0-py3-none-any.whl -O /tmp/zh_core_web_sm-3.7.0-py3-none-any.whl \ - && pip3 install /tmp/zh_core_web_sm-3.7.0-py3-none-any.whl -i $PIP_INDEX_URL \ - && rm /tmp/zh_core_web_sm-3.7.0-py3-none-any.whl; \ - # not zh, download directly - else python3 -m spacy download zh_core_web_sm; \ - fi;) \ - && rm -rf `pip3 cache dir` +# Not install spacy model for now +#RUN (if [ "${LANGUAGE}" = "zh" ]; \ +# # language is zh, download zh_core_web_sm from github +# then wget https://github.com/explosion/spacy-models/releases/download/zh_core_web_sm-3.7.0/zh_core_web_sm-3.7.0-py3-none-any.whl -O /tmp/zh_core_web_sm-3.7.0-py3-none-any.whl \ +# && pip3 install /tmp/zh_core_web_sm-3.7.0-py3-none-any.whl -i $PIP_INDEX_URL \ +# && rm /tmp/zh_core_web_sm-3.7.0-py3-none-any.whl; \ +# # not zh, download directly +# else python3 -m spacy download zh_core_web_sm; \ +# fi;) \ +# && rm -rf `pip3 cache dir` ARG BUILD_LOCAL_CODE="false" # COPY the rest of the app diff --git a/docker/base/build_proxy_image.sh b/docker/base/build_proxy_image.sh new file mode 100755 index 000000000..c3306f1db --- /dev/null +++ b/docker/base/build_proxy_image.sh @@ -0,0 +1,17 @@ +#!/bin/bash + +SCRIPT_LOCATION=$0 +cd "$(dirname "$SCRIPT_LOCATION")" +WORK_DIR=$(pwd) + +if [[ " $* " == *" --help "* ]] || [[ " $* " == *" -h "* ]]; then + bash $WORK_DIR/build_image.sh "$@" + exit 0 +fi + +bash $WORK_DIR/build_image.sh --install-mode openai "$@" + +if [ 0 -ne $? ]; then + echo "Error: build base image failed" + exit 1 +fi \ No newline at end of file diff --git a/docker/compose_examples/ha-cluster-docker-compose.yml b/docker/compose_examples/ha-cluster-docker-compose.yml new file mode 100644 index 000000000..412a7ab7a --- /dev/null +++ b/docker/compose_examples/ha-cluster-docker-compose.yml @@ -0,0 +1,127 @@ +# An example of using docker-compose to start a HA model serving cluster with two controllers and one worker. +# For simplicity, we use chatgpt_proxyllm as the model for the worker, and we build a new docker image named eosphorosai/dbgpt-openai:latest. +# How to build the image: +# run `bash ./docker/base/build_proxy_image.sh` in the root directory of the project. 
+# If you want to use another pip index url, you can run the command with the `--pip-index-url` option.
+# For example, `bash ./docker/base/build_proxy_image.sh --pip-index-url https://pypi.tuna.tsinghua.edu.cn/simple`
+#
+# How to start the cluster:
+# 1. run `cd docker/compose_examples`
+# 2. run `OPENAI_API_KEY="{your api key}" OPENAI_API_BASE="https://api.openai.com/v1" docker compose -f ha-cluster-docker-compose.yml up -d`
+# Note: Make sure you have set the OPENAI_API_KEY environment variable.
version: '3.9'

services:
  init:
    image: busybox
    volumes:
      - ../examples/sqls:/sqls
      - ../../assets/schema/dbgpt.sql:/dbgpt.sql
      - dbgpt-init-scripts:/docker-entrypoint-initdb.d
    command: /bin/sh -c "cp /dbgpt.sql /docker-entrypoint-initdb.d/ && cp /sqls/* /docker-entrypoint-initdb.d/ && ls /docker-entrypoint-initdb.d/"

  db:
    image: mysql/mysql-server
    environment:
      MYSQL_USER: 'user'
      MYSQL_PASSWORD: 'password'
      MYSQL_ROOT_PASSWORD: 'aa123456'
    ports:
      - 3306:3306
    volumes:
      - dbgpt-myql-db:/var/lib/mysql
      - ../examples/my.cnf:/etc/my.cnf
      - dbgpt-init-scripts:/docker-entrypoint-initdb.d
    restart: unless-stopped
    networks:
      - dbgptnet
    depends_on:
      - init
  controller-1:
    image: eosphorosai/dbgpt-openai:latest
    command: dbgpt start controller --registry_type database --registry_db_type mysql --registry_db_name dbgpt --registry_db_host db --registry_db_port 3306 --registry_db_user root --registry_db_password aa123456
    volumes:
      - ../../:/app
    restart: unless-stopped
    networks:
      - dbgptnet
    depends_on:
      - db
  controller-2:
    image: eosphorosai/dbgpt-openai:latest
    command: dbgpt start controller --registry_type database --registry_db_type mysql --registry_db_name dbgpt --registry_db_host db --registry_db_port 3306 --registry_db_user root --registry_db_password aa123456
    volumes:
      - ../../:/app
    restart: unless-stopped
    networks:
      - dbgptnet
    depends_on:
      - db
  llm-worker:
    image: eosphorosai/dbgpt-openai:latest
    command: dbgpt start worker --model_type proxy --model_name chatgpt_proxyllm --model_path chatgpt_proxyllm --proxy_server_url ${OPENAI_API_BASE}/chat/completions --proxy_api_key ${OPENAI_API_KEY} --controller_addr "http://controller-1:8000,http://controller-2:8000"
    environment:
      - DBGPT_LOG_LEVEL=DEBUG
      # Your real openai model name, e.g. gpt-3.5-turbo, gpt-4o
      - PROXYLLM_BACKEND=gpt-3.5-turbo
    depends_on:
      - controller-1
      - controller-2
    volumes:
      - ../../:/app
    restart: unless-stopped
    networks:
      - dbgptnet
    ipc: host
  embedding-worker:
    image: eosphorosai/dbgpt-openai:latest
    command: dbgpt start worker --worker_type text2vec --model_name proxy_http_openapi --model_path proxy_http_openapi --proxy_server_url ${OPENAI_API_BASE}/embeddings --proxy_api_key ${OPENAI_API_KEY} --controller_addr "http://controller-1:8000,http://controller-2:8000"
    environment:
      - DBGPT_LOG_LEVEL=DEBUG
      - proxy_http_openapi_proxy_backend=text-embedding-3-small
    depends_on:
      - controller-1
      - controller-2
    volumes:
      - ../../:/app
    restart: unless-stopped
    networks:
      - dbgptnet
    ipc: host
  webserver:
    image: eosphorosai/dbgpt-openai:latest
    command: dbgpt start webserver --light --remote_embedding --controller_addr "http://controller-1:8000,http://controller-2:8000"
    environment:
      - DBGPT_LOG_LEVEL=DEBUG
      - LOCAL_DB_TYPE=mysql
      - LOCAL_DB_HOST=db
      - LOCAL_DB_USER=root
      - LOCAL_DB_PASSWORD=aa123456
      - LLM_MODEL=chatgpt_proxyllm
      - EMBEDDING_MODEL=proxy_http_openapi
    depends_on:
      - controller-1
      - controller-2
      - llm-worker
      - embedding-worker
    volumes:
      - ../../:/app
      - dbgpt-data:/app/pilot/data
      - dbgpt-message:/app/pilot/message
    # env_file:
    #   - .env.template
    ports:
      - 5670:5670/tcp
    # The webserver may fail on first start: it has to wait until all SQL scripts in /docker-entrypoint-initdb.d have finished executing.
    restart: unless-stopped
    networks:
      - dbgptnet
volumes:
  dbgpt-init-scripts:
  dbgpt-myql-db:
  dbgpt-data:
  dbgpt-message:
networks:
  dbgptnet:
    driver: bridge
    name: dbgptnet
\ No newline at end of file
diff --git a/docs/docs/api/datasource.md b/docs/docs/api/datasource.md
index 7553bb673..e6e6cff23 100644
--- a/docs/docs/api/datasource.md
+++ b/docs/docs/api/datasource.md
@@ -34,7 +34,7 @@ curl -X POST "http://localhost:5670/api/v2/chat/completions" \
 -H "Authorization: Bearer $DBGPT_API_KEY" \
 -H "accept: application/json" \
 -H "Content-Type: application/json" \
--d "{\"messages\":\"show space datas limit 5\",\"model\":\"chatgpt_proxyllm\", \"chat_mode\": \"chat_datasource\", \"chat_param\": \"$DB_NAME\"}"
+-d "{\"messages\":\"show space datas limit 5\",\"model\":\"chatgpt_proxyllm\", \"chat_mode\": \"chat_data\", \"chat_param\": \"$DB_NAME\"}"
 ```
diff --git a/docs/docs/awel/get_started.ipynb b/docs/docs/awel/get_started.ipynb
index 4b46adc2a..063cde67c 100644
--- a/docs/docs/awel/get_started.ipynb
+++ b/docs/docs/awel/get_started.ipynb
@@ -47,11 +47,11 @@
    "cell_type": "code",
    "execution_count": null,
    "metadata": {},
-   "outputs": [],
    "source": [
     "from dbgpt._private.pydantic import BaseModel, Field\n",
     "from dbgpt.core.awel import DAG, HttpTrigger, MapOperator"
-   ]
+   ],
+   "outputs": []
   },
   {
    "attachments": {},
@@ -67,12 +67,12 @@
    "cell_type": "code",
    "execution_count": null,
    "metadata": {},
-   "outputs": [],
    "source": [
     "class TriggerReqBody(BaseModel):\n",
     "    name: str = Field(..., description=\"User name\")\n",
     "    age: int = Field(18, description=\"User age\")"
-   ]
+   ],
+   "outputs": []
   },
   {
    "attachments": {},
@@ -87,7 +87,6 @@
    "cell_type": "code",
    "execution_count": null,
    "metadata": {},
-   "outputs": [],
    "source": [
     "class RequestHandleOperator(MapOperator[TriggerReqBody, str]):\n",
     "    def __init__(self, **kwargs):\n",
@@ -96,7 +95,8 @@
     "    async def map(self, input_value: TriggerReqBody) -> str:\n",
     "        print(f\"Receive input value: {input_value}\")\n",
     "        return f\"Hello, {input_value.name}, your age is {input_value.age}\""
-   ]
+   ],
+   "outputs": []
   },
   {
    "attachments": {},
@@ -112,13 +112,13 @@
    "cell_type": "code",
    "execution_count": null,
    "metadata": {},
-   "outputs": [],
    "source": [
     "with DAG(\"simple_dag_example\") as dag:\n",
     "    trigger = HttpTrigger(\"/examples/hello\", request_body=TriggerReqBody)\n",
     "    map_node = RequestHandleOperator()\n",
     "    trigger >> map_node"
-   ]
+   ],
+   "outputs": []
   },
   {
    "attachments": {},
@@ -138,12 +138,12 @@
      "languageId": "powershell"
     }
    },
-   "outputs": [],
   "source": [
    "\n",
    "% curl -X GET http://127.0.0.1:5000/api/v1/awel/trigger/examples/hello\\?name\\=zhangsan\n",
    "\"Hello, zhangsan, your age is 18\""
-   ]
+   ],
+   "outputs": []
  },
  {
   "attachments": {},
@@ -161,7 +161,6 @@
      "languageId": "powershell"
     }
    },
-   "outputs": [],
   "source": [
    "if __name__ == \"__main__\":\n",
    "    if dag.leaf_nodes[0].dev_mode:\n",
@@ -171,7 +170,8 @@
    "    else:\n",
    "        # Production mode, DB-GPT will automatically load and execute the current file after startup.\n",
    "        pass"
-   ]
+   ],
+   "outputs": []
  },
 {
  "cell_type": "code",
@@ -181,11 +181,11 @@
      "languageId": "powershell"
     }
    },
-   "outputs": [],
   "source": [
    "curl -X GET http://127.0.0.1:5555/api/v1/awel/trigger/examples/hello\\?name\\=zhangsan\n",
    "\"Hello, zhangsan, your age is 18\""
-   ]
+   ],
+   "outputs": []
  },
 {
  "attachments": {},
diff --git a/docs/docs/installation/model_service/cluster.md b/docs/docs/installation/model_service/cluster.md
index fea91a792..ade43c48a 100644
--- a/docs/docs/installation/model_service/cluster.md
+++ b/docs/docs/installation/model_service/cluster.md
@@ -124,7 +124,7 @@ MODEL_SERVER=http://127.0.0.1:8000
 Or it can be started directly by command to formulate the model.
 
 ```shell
-LLM_MODEL=glm-4-9b-chat dbgpt start webserver --light
+LLM_MODEL=glm-4-9b-chat dbgpt start webserver --light --remote_embedding
 ```
 
 ## Command line usage
diff --git a/docs/docs/installation/model_service/cluster_ha.md b/docs/docs/installation/model_service/cluster_ha.md
new file mode 100644
index 000000000..8ca7363b1
--- /dev/null
+++ b/docs/docs/installation/model_service/cluster_ha.md
@@ -0,0 +1,171 @@
+# High Availability
+
+
+## Architecture
+
+Here is the architecture of the high-availability cluster; more details can be found in
+the [cluster deployment](/docs/latest/installation/model_service/cluster) docs and the [SMMF](/docs/latest/modules/smmf) module.
+
+
+
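+The exact commands depend on your environment, but as a minimal sketch (the
+`--registry_*` flags are the ones added to `ModelControllerParameters` in this
+change, and the host names below are placeholders taken from the compose
+example), an HA deployment runs two controllers against a shared MySQL
+registry and points every worker at both of them:
+
+```shell
+# On each controller node (both controllers share the same registry database)
+dbgpt start controller --port 8000 \
+    --registry_type database \
+    --registry_db_type mysql \
+    --registry_db_name dbgpt \
+    --registry_db_host db \
+    --registry_db_port 3306 \
+    --registry_db_user root \
+    --registry_db_password "$CONTROLLER_REGISTRY_DB_PASSWORD"
+
+# On a worker node: pass all controller addresses, comma-separated, so the
+# worker can fail over if one controller becomes unhealthy
+dbgpt start worker --model_type proxy --model_name chatgpt_proxyllm \
+    --model_path chatgpt_proxyllm \
+    --controller_addr "http://controller-1:8000,http://controller-2:8000"
+```
+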