connect database

This commit is contained in:
csunny 2023-05-03 22:10:35 +08:00
parent a71c8b6d56
commit 88a5c57646
4 changed files with 87 additions and 26 deletions

View File

@ -25,3 +25,11 @@ vicuna_model_server = "http://192.168.31.114:8000"
# Load model config
isload_8bit = True
isdebug = False
DB_SETTINGS = {
"user": "root",
"password": "********",
"host": "localhost",
"port": 3306
}

View File

@ -5,6 +5,8 @@ import pymysql
class MySQLOperator:
"""Connect MySQL Database fetch MetaData For LLM Prompt """
default_db = ["information_schema", "performance_schema", "sys", "mysql"]
def __init__(self, user, password, host="localhost", port=3306) -> None:
self.conn = pymysql.connect(
@ -25,12 +27,16 @@ class MySQLOperator:
results = cursor.fetchall()
return results
def get_db_list(self):
with self.conn.cursor() as cursor:
_sql = """
show databases;
"""
cursor.execute(_sql)
results = cursor.fetchall()
dbs = [d["Database"] for d in results if d["Database"] not in self.default_db]
return dbs
if __name__ == "__main__":
mo = MySQLOperator(
"root",
"aa123456",
)
schema = mo.get_schema("dbgpt")
print(schema)

View File

@ -4,7 +4,7 @@
import dataclasses
from enum import auto, Enum
from typing import List, Any
from pilot.configs.model_config import DB_SETTINGS
class SeparatorStyle(Enum):
@ -88,6 +88,19 @@ class Conversation:
}
def gen_sqlgen_conversation(dbname):
from pilot.connections.mysql_conn import MySQLOperator
mo = MySQLOperator(
**DB_SETTINGS
)
message = ""
schemas = mo.get_schema(dbname)
for s in schemas:
message += s["schema_info"] + ";"
return f"数据库{dbname}的Schema信息如下: {message}\n"
conv_one_shot = Conversation(
system="A chat between a curious human and an artificial intelligence assistant, who very familiar with database related knowledge. "
"The assistant gives helpful, detailed, professional and polite answers to the human's questions. ",
@ -121,7 +134,7 @@ conv_one_shot = Conversation(
sep_style=SeparatorStyle.SINGLE,
sep="###"
)
conv_vicuna_v1 = Conversation(
system = "A chat between a curious user and an artificial intelligence assistant. who very familiar with database related knowledge. "
"The assistant gives helpful, detailed, professional and polite answers to the user's questions. ",
@ -137,5 +150,10 @@ default_conversation = conv_one_shot
conv_templates = {
"conv_one_shot": conv_one_shot,
"vicuna_v1": conv_vicuna_v1
"vicuna_v1": conv_vicuna_v1,
}
if __name__ == "__main__":
message = gen_sqlgen_conversation("dbgpt")
print(message)

View File

@ -10,6 +10,9 @@ import gradio as gr
import datetime
import requests
from urllib.parse import urljoin
from pilot.configs.model_config import DB_SETTINGS
from pilot.connections.mysql_conn import MySQLOperator
from pilot.configs.model_config import LOGDIR, vicuna_model_server, LLM_MODEL
@ -29,7 +32,7 @@ from fastchat.utils import (
from fastchat.serve.gradio_patch import Chatbot as grChatbot
from fastchat.serve.gradio_css import code_highlight_css
logger = build_logger("webserver", "webserver.log")
logger = build_logger("webserver", LOGDIR + "webserver.log")
headers = {"User-Agent": "dbgpt Client"}
no_change_btn = gr.Button.update()
@ -38,11 +41,28 @@ disable_btn = gr.Button.update(interactive=True)
enable_moderation = False
models = []
dbs = []
priority = {
"vicuna-13b": "aaa"
}
def gen_sqlgen_conversation(dbname):
mo = MySQLOperator(
**DB_SETTINGS
)
message = ""
schemas = mo.get_schema(dbname)
for s in schemas:
message += s["schema_info"] + ";"
return f"数据库{dbname}的Schema信息如下: {message}\n"
def get_database_list():
mo = MySQLOperator(**DB_SETTINGS)
return mo.get_db_list()
get_window_url_params = """
function() {
const params = new URLSearchParams(window.location.search);
@ -58,12 +78,10 @@ function() {
def load_demo(url_params, request: gr.Request):
logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
dbs = get_database_list()
dropdown_update = gr.Dropdown.update(visible=True)
if "model" in url_params:
model = url_params["model"]
if model in models:
dropdown_update = gr.Dropdown.update(
value=model, visible=True)
if dbs:
gr.Dropdown.update(choices=dbs)
state = default_conversation.copy()
return (state,
@ -120,10 +138,11 @@ def post_process_code(code):
code = sep.join(blocks)
return code
def http_bot(state, temperature, max_new_tokens, request: gr.Request):
def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Request):
start_tstamp = time.time()
model_name = LLM_MODEL
dbname = db_selector
# TODO 这里的请求需要拼接现有知识库, 使得其根据现有知识库作答, 所以prompt需要继续优化
if state.skip_next:
# This generate call is skipped due to invalid inputs
@ -131,16 +150,20 @@ def http_bot(state, temperature, max_new_tokens, request: gr.Request):
return
if len(state.messages) == state.offset + 2:
# First round of conversation
# 第一轮对话需要加入提示Prompt
template_name = "conv_one_shot"
new_state = conv_templates[template_name].copy()
new_state.conv_id = uuid.uuid4().hex
new_state.append_message(new_state.roles[0], state.messages[-2][1])
# prompt 中添加上下文提示
new_state.append_message(new_state.roles[0], gen_sqlgen_conversation(dbname) + state.messages[-2][1])
new_state.append_message(new_state.roles[1], None)
state = new_state
prompt = state.get_prompt()
skip_echo_len = len(prompt.replace("</s>", " ")) + 1
# Make requests
@ -250,11 +273,16 @@ def build_single_model_ui():
with gr.Tabs():
with gr.TabItem("知识问答", elem_id="QA"):
pass
pass
with gr.TabItem("SQL生成与诊断", elem_id="SQL"):
# TODO A selector to choose database
pass
with gr.Row(elem_id="db_selector"):
db_selector = gr.Dropdown(
label="请选择数据库",
choices=dbs,
value=dbs[0] if len(models) > 0 else "",
interactive=True,
show_label=True).style(container=False)
with gr.Blocks():
chatbot = grChatbot(elem_id="chatbot", visible=False).style(height=550)
@ -277,7 +305,7 @@ def build_single_model_ui():
btn_list = [regenerate_btn, clear_btn]
regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
http_bot,
[state, temperature, max_output_tokens],
[state, db_selector, temperature, max_output_tokens],
[state, chatbot] + btn_list,
)
clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)
@ -286,7 +314,7 @@ def build_single_model_ui():
add_text, [state, textbox], [state, chatbot, textbox] + btn_list
).then(
http_bot,
[state, temperature, max_output_tokens],
[state, db_selector, temperature, max_output_tokens],
[state, chatbot] + btn_list,
)
@ -294,7 +322,7 @@ def build_single_model_ui():
add_text, [state, textbox], [state, chatbot, textbox] + btn_list
).then(
http_bot,
[state, temperature, max_output_tokens],
[state, db_selector, temperature, max_output_tokens],
[state, chatbot] + btn_list
)
@ -351,6 +379,7 @@ if __name__ == "__main__":
args = parser.parse_args()
logger.info(f"args: {args}")
dbs = get_database_list()
logger.info(args)
demo = build_webdemo()
demo.queue(