From 88a5c576469b0db2f334b1dc54eee755c55bb448 Mon Sep 17 00:00:00 2001 From: csunny Date: Wed, 3 May 2023 22:10:35 +0800 Subject: [PATCH] connect database --- pilot/configs/model_config.py | 8 +++++ pilot/connections/mysql_conn.py | 20 +++++++---- pilot/conversation.py | 24 +++++++++++-- pilot/server/webserver.py | 61 ++++++++++++++++++++++++--------- 4 files changed, 87 insertions(+), 26 deletions(-) diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py index 80b1dabe4..ad8eb83c1 100644 --- a/pilot/configs/model_config.py +++ b/pilot/configs/model_config.py @@ -25,3 +25,11 @@ vicuna_model_server = "http://192.168.31.114:8000" # Load model config isload_8bit = True isdebug = False + + +DB_SETTINGS = { + "user": "root", + "password": "********", + "host": "localhost", + "port": 3306 +} \ No newline at end of file diff --git a/pilot/connections/mysql_conn.py b/pilot/connections/mysql_conn.py index 3825d425f..2dfff2ee7 100644 --- a/pilot/connections/mysql_conn.py +++ b/pilot/connections/mysql_conn.py @@ -5,6 +5,8 @@ import pymysql class MySQLOperator: """Connect MySQL Database fetch MetaData For LLM Prompt """ + + default_db = ["information_schema", "performance_schema", "sys", "mysql"] def __init__(self, user, password, host="localhost", port=3306) -> None: self.conn = pymysql.connect( @@ -25,12 +27,16 @@ class MySQLOperator: results = cursor.fetchall() return results + def get_db_list(self): + with self.conn.cursor() as cursor: + _sql = """ + show databases; + """ + cursor.execute(_sql) + results = cursor.fetchall() + + dbs = [d["Database"] for d in results if d["Database"] not in self.default_db] + return dbs + -if __name__ == "__main__": - mo = MySQLOperator( - "root", - "aa123456", - ) - schema = mo.get_schema("dbgpt") - print(schema) \ No newline at end of file diff --git a/pilot/conversation.py b/pilot/conversation.py index 8e172e7dd..e88ceaccb 100644 --- a/pilot/conversation.py +++ b/pilot/conversation.py @@ -4,7 +4,7 @@ import dataclasses from enum import auto, Enum from typing import List, Any - +from pilot.configs.model_config import DB_SETTINGS class SeparatorStyle(Enum): @@ -88,6 +88,19 @@ class Conversation: } +def gen_sqlgen_conversation(dbname): + from pilot.connections.mysql_conn import MySQLOperator + mo = MySQLOperator( + **DB_SETTINGS + ) + + message = "" + + schemas = mo.get_schema(dbname) + for s in schemas: + message += s["schema_info"] + ";" + return f"数据库{dbname}的Schema信息如下: {message}\n" + conv_one_shot = Conversation( system="A chat between a curious human and an artificial intelligence assistant, who very familiar with database related knowledge. " "The assistant gives helpful, detailed, professional and polite answers to the human's questions. ", @@ -121,7 +134,7 @@ conv_one_shot = Conversation( sep_style=SeparatorStyle.SINGLE, sep="###" ) - + conv_vicuna_v1 = Conversation( system = "A chat between a curious user and an artificial intelligence assistant. who very familiar with database related knowledge. " "The assistant gives helpful, detailed, professional and polite answers to the user's questions. ", @@ -137,5 +150,10 @@ default_conversation = conv_one_shot conv_templates = { "conv_one_shot": conv_one_shot, - "vicuna_v1": conv_vicuna_v1 + "vicuna_v1": conv_vicuna_v1, } + + +if __name__ == "__main__": + message = gen_sqlgen_conversation("dbgpt") + print(message) \ No newline at end of file diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py index f620b2d8d..c13a5331f 100644 --- a/pilot/server/webserver.py +++ b/pilot/server/webserver.py @@ -10,6 +10,9 @@ import gradio as gr import datetime import requests from urllib.parse import urljoin +from pilot.configs.model_config import DB_SETTINGS +from pilot.connections.mysql_conn import MySQLOperator + from pilot.configs.model_config import LOGDIR, vicuna_model_server, LLM_MODEL @@ -29,7 +32,7 @@ from fastchat.utils import ( from fastchat.serve.gradio_patch import Chatbot as grChatbot from fastchat.serve.gradio_css import code_highlight_css -logger = build_logger("webserver", "webserver.log") +logger = build_logger("webserver", LOGDIR + "webserver.log") headers = {"User-Agent": "dbgpt Client"} no_change_btn = gr.Button.update() @@ -38,11 +41,28 @@ disable_btn = gr.Button.update(interactive=True) enable_moderation = False models = [] +dbs = [] priority = { "vicuna-13b": "aaa" } +def gen_sqlgen_conversation(dbname): + mo = MySQLOperator( + **DB_SETTINGS + ) + + message = "" + + schemas = mo.get_schema(dbname) + for s in schemas: + message += s["schema_info"] + ";" + return f"数据库{dbname}的Schema信息如下: {message}\n" + +def get_database_list(): + mo = MySQLOperator(**DB_SETTINGS) + return mo.get_db_list() + get_window_url_params = """ function() { const params = new URLSearchParams(window.location.search); @@ -58,12 +78,10 @@ function() { def load_demo(url_params, request: gr.Request): logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}") + dbs = get_database_list() dropdown_update = gr.Dropdown.update(visible=True) - if "model" in url_params: - model = url_params["model"] - if model in models: - dropdown_update = gr.Dropdown.update( - value=model, visible=True) + if dbs: + gr.Dropdown.update(choices=dbs) state = default_conversation.copy() return (state, @@ -120,10 +138,11 @@ def post_process_code(code): code = sep.join(blocks) return code -def http_bot(state, temperature, max_new_tokens, request: gr.Request): +def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Request): start_tstamp = time.time() model_name = LLM_MODEL + dbname = db_selector # TODO 这里的请求需要拼接现有知识库, 使得其根据现有知识库作答, 所以prompt需要继续优化 if state.skip_next: # This generate call is skipped due to invalid inputs @@ -131,16 +150,20 @@ def http_bot(state, temperature, max_new_tokens, request: gr.Request): return if len(state.messages) == state.offset + 2: - # First round of conversation + # 第一轮对话需要加入提示Prompt template_name = "conv_one_shot" new_state = conv_templates[template_name].copy() new_state.conv_id = uuid.uuid4().hex - new_state.append_message(new_state.roles[0], state.messages[-2][1]) + + # prompt 中添加上下文提示 + new_state.append_message(new_state.roles[0], gen_sqlgen_conversation(dbname) + state.messages[-2][1]) new_state.append_message(new_state.roles[1], None) state = new_state - + + prompt = state.get_prompt() + skip_echo_len = len(prompt.replace("", " ")) + 1 # Make requests @@ -250,11 +273,16 @@ def build_single_model_ui(): with gr.Tabs(): with gr.TabItem("知识问答", elem_id="QA"): - pass - + pass with gr.TabItem("SQL生成与诊断", elem_id="SQL"): # TODO A selector to choose database - pass + with gr.Row(elem_id="db_selector"): + db_selector = gr.Dropdown( + label="请选择数据库", + choices=dbs, + value=dbs[0] if len(models) > 0 else "", + interactive=True, + show_label=True).style(container=False) with gr.Blocks(): chatbot = grChatbot(elem_id="chatbot", visible=False).style(height=550) @@ -277,7 +305,7 @@ def build_single_model_ui(): btn_list = [regenerate_btn, clear_btn] regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then( http_bot, - [state, temperature, max_output_tokens], + [state, db_selector, temperature, max_output_tokens], [state, chatbot] + btn_list, ) clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list) @@ -286,7 +314,7 @@ def build_single_model_ui(): add_text, [state, textbox], [state, chatbot, textbox] + btn_list ).then( http_bot, - [state, temperature, max_output_tokens], + [state, db_selector, temperature, max_output_tokens], [state, chatbot] + btn_list, ) @@ -294,7 +322,7 @@ def build_single_model_ui(): add_text, [state, textbox], [state, chatbot, textbox] + btn_list ).then( http_bot, - [state, temperature, max_output_tokens], + [state, db_selector, temperature, max_output_tokens], [state, chatbot] + btn_list ) @@ -351,6 +379,7 @@ if __name__ == "__main__": args = parser.parse_args() logger.info(f"args: {args}") + dbs = get_database_list() logger.info(args) demo = build_webdemo() demo.queue(