mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-24 12:45:45 +00:00
connect database
This commit is contained in:
parent
a71c8b6d56
commit
88a5c57646
@ -25,3 +25,11 @@ vicuna_model_server = "http://192.168.31.114:8000"
|
||||
# Load model config
|
||||
isload_8bit = True
|
||||
isdebug = False
|
||||
|
||||
|
||||
DB_SETTINGS = {
|
||||
"user": "root",
|
||||
"password": "********",
|
||||
"host": "localhost",
|
||||
"port": 3306
|
||||
}
|
@ -5,6 +5,8 @@ import pymysql
|
||||
|
||||
class MySQLOperator:
|
||||
"""Connect MySQL Database fetch MetaData For LLM Prompt """
|
||||
|
||||
default_db = ["information_schema", "performance_schema", "sys", "mysql"]
|
||||
def __init__(self, user, password, host="localhost", port=3306) -> None:
|
||||
|
||||
self.conn = pymysql.connect(
|
||||
@ -25,12 +27,16 @@ class MySQLOperator:
|
||||
results = cursor.fetchall()
|
||||
return results
|
||||
|
||||
def get_db_list(self):
|
||||
with self.conn.cursor() as cursor:
|
||||
_sql = """
|
||||
show databases;
|
||||
"""
|
||||
cursor.execute(_sql)
|
||||
results = cursor.fetchall()
|
||||
|
||||
dbs = [d["Database"] for d in results if d["Database"] not in self.default_db]
|
||||
return dbs
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mo = MySQLOperator(
|
||||
"root",
|
||||
"aa123456",
|
||||
)
|
||||
|
||||
schema = mo.get_schema("dbgpt")
|
||||
print(schema)
|
@ -4,7 +4,7 @@
|
||||
import dataclasses
|
||||
from enum import auto, Enum
|
||||
from typing import List, Any
|
||||
|
||||
from pilot.configs.model_config import DB_SETTINGS
|
||||
|
||||
class SeparatorStyle(Enum):
|
||||
|
||||
@ -88,6 +88,19 @@ class Conversation:
|
||||
}
|
||||
|
||||
|
||||
def gen_sqlgen_conversation(dbname):
|
||||
from pilot.connections.mysql_conn import MySQLOperator
|
||||
mo = MySQLOperator(
|
||||
**DB_SETTINGS
|
||||
)
|
||||
|
||||
message = ""
|
||||
|
||||
schemas = mo.get_schema(dbname)
|
||||
for s in schemas:
|
||||
message += s["schema_info"] + ";"
|
||||
return f"数据库{dbname}的Schema信息如下: {message}\n"
|
||||
|
||||
conv_one_shot = Conversation(
|
||||
system="A chat between a curious human and an artificial intelligence assistant, who very familiar with database related knowledge. "
|
||||
"The assistant gives helpful, detailed, professional and polite answers to the human's questions. ",
|
||||
@ -121,7 +134,7 @@ conv_one_shot = Conversation(
|
||||
sep_style=SeparatorStyle.SINGLE,
|
||||
sep="###"
|
||||
)
|
||||
|
||||
|
||||
conv_vicuna_v1 = Conversation(
|
||||
system = "A chat between a curious user and an artificial intelligence assistant. who very familiar with database related knowledge. "
|
||||
"The assistant gives helpful, detailed, professional and polite answers to the user's questions. ",
|
||||
@ -137,5 +150,10 @@ default_conversation = conv_one_shot
|
||||
|
||||
conv_templates = {
|
||||
"conv_one_shot": conv_one_shot,
|
||||
"vicuna_v1": conv_vicuna_v1
|
||||
"vicuna_v1": conv_vicuna_v1,
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
message = gen_sqlgen_conversation("dbgpt")
|
||||
print(message)
|
@ -10,6 +10,9 @@ import gradio as gr
|
||||
import datetime
|
||||
import requests
|
||||
from urllib.parse import urljoin
|
||||
from pilot.configs.model_config import DB_SETTINGS
|
||||
from pilot.connections.mysql_conn import MySQLOperator
|
||||
|
||||
|
||||
from pilot.configs.model_config import LOGDIR, vicuna_model_server, LLM_MODEL
|
||||
|
||||
@ -29,7 +32,7 @@ from fastchat.utils import (
|
||||
from fastchat.serve.gradio_patch import Chatbot as grChatbot
|
||||
from fastchat.serve.gradio_css import code_highlight_css
|
||||
|
||||
logger = build_logger("webserver", "webserver.log")
|
||||
logger = build_logger("webserver", LOGDIR + "webserver.log")
|
||||
headers = {"User-Agent": "dbgpt Client"}
|
||||
|
||||
no_change_btn = gr.Button.update()
|
||||
@ -38,11 +41,28 @@ disable_btn = gr.Button.update(interactive=True)
|
||||
|
||||
enable_moderation = False
|
||||
models = []
|
||||
dbs = []
|
||||
|
||||
priority = {
|
||||
"vicuna-13b": "aaa"
|
||||
}
|
||||
|
||||
def gen_sqlgen_conversation(dbname):
|
||||
mo = MySQLOperator(
|
||||
**DB_SETTINGS
|
||||
)
|
||||
|
||||
message = ""
|
||||
|
||||
schemas = mo.get_schema(dbname)
|
||||
for s in schemas:
|
||||
message += s["schema_info"] + ";"
|
||||
return f"数据库{dbname}的Schema信息如下: {message}\n"
|
||||
|
||||
def get_database_list():
|
||||
mo = MySQLOperator(**DB_SETTINGS)
|
||||
return mo.get_db_list()
|
||||
|
||||
get_window_url_params = """
|
||||
function() {
|
||||
const params = new URLSearchParams(window.location.search);
|
||||
@ -58,12 +78,10 @@ function() {
|
||||
def load_demo(url_params, request: gr.Request):
|
||||
logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
|
||||
|
||||
dbs = get_database_list()
|
||||
dropdown_update = gr.Dropdown.update(visible=True)
|
||||
if "model" in url_params:
|
||||
model = url_params["model"]
|
||||
if model in models:
|
||||
dropdown_update = gr.Dropdown.update(
|
||||
value=model, visible=True)
|
||||
if dbs:
|
||||
gr.Dropdown.update(choices=dbs)
|
||||
|
||||
state = default_conversation.copy()
|
||||
return (state,
|
||||
@ -120,10 +138,11 @@ def post_process_code(code):
|
||||
code = sep.join(blocks)
|
||||
return code
|
||||
|
||||
def http_bot(state, temperature, max_new_tokens, request: gr.Request):
|
||||
def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Request):
|
||||
start_tstamp = time.time()
|
||||
model_name = LLM_MODEL
|
||||
|
||||
dbname = db_selector
|
||||
# TODO 这里的请求需要拼接现有知识库, 使得其根据现有知识库作答, 所以prompt需要继续优化
|
||||
if state.skip_next:
|
||||
# This generate call is skipped due to invalid inputs
|
||||
@ -131,16 +150,20 @@ def http_bot(state, temperature, max_new_tokens, request: gr.Request):
|
||||
return
|
||||
|
||||
if len(state.messages) == state.offset + 2:
|
||||
# First round of conversation
|
||||
# 第一轮对话需要加入提示Prompt
|
||||
|
||||
template_name = "conv_one_shot"
|
||||
new_state = conv_templates[template_name].copy()
|
||||
new_state.conv_id = uuid.uuid4().hex
|
||||
new_state.append_message(new_state.roles[0], state.messages[-2][1])
|
||||
|
||||
# prompt 中添加上下文提示
|
||||
new_state.append_message(new_state.roles[0], gen_sqlgen_conversation(dbname) + state.messages[-2][1])
|
||||
new_state.append_message(new_state.roles[1], None)
|
||||
state = new_state
|
||||
|
||||
|
||||
|
||||
prompt = state.get_prompt()
|
||||
|
||||
skip_echo_len = len(prompt.replace("</s>", " ")) + 1
|
||||
|
||||
# Make requests
|
||||
@ -250,11 +273,16 @@ def build_single_model_ui():
|
||||
|
||||
with gr.Tabs():
|
||||
with gr.TabItem("知识问答", elem_id="QA"):
|
||||
pass
|
||||
|
||||
pass
|
||||
with gr.TabItem("SQL生成与诊断", elem_id="SQL"):
|
||||
# TODO A selector to choose database
|
||||
pass
|
||||
with gr.Row(elem_id="db_selector"):
|
||||
db_selector = gr.Dropdown(
|
||||
label="请选择数据库",
|
||||
choices=dbs,
|
||||
value=dbs[0] if len(models) > 0 else "",
|
||||
interactive=True,
|
||||
show_label=True).style(container=False)
|
||||
|
||||
with gr.Blocks():
|
||||
chatbot = grChatbot(elem_id="chatbot", visible=False).style(height=550)
|
||||
@ -277,7 +305,7 @@ def build_single_model_ui():
|
||||
btn_list = [regenerate_btn, clear_btn]
|
||||
regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
|
||||
http_bot,
|
||||
[state, temperature, max_output_tokens],
|
||||
[state, db_selector, temperature, max_output_tokens],
|
||||
[state, chatbot] + btn_list,
|
||||
)
|
||||
clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)
|
||||
@ -286,7 +314,7 @@ def build_single_model_ui():
|
||||
add_text, [state, textbox], [state, chatbot, textbox] + btn_list
|
||||
).then(
|
||||
http_bot,
|
||||
[state, temperature, max_output_tokens],
|
||||
[state, db_selector, temperature, max_output_tokens],
|
||||
[state, chatbot] + btn_list,
|
||||
)
|
||||
|
||||
@ -294,7 +322,7 @@ def build_single_model_ui():
|
||||
add_text, [state, textbox], [state, chatbot, textbox] + btn_list
|
||||
).then(
|
||||
http_bot,
|
||||
[state, temperature, max_output_tokens],
|
||||
[state, db_selector, temperature, max_output_tokens],
|
||||
[state, chatbot] + btn_list
|
||||
)
|
||||
|
||||
@ -351,6 +379,7 @@ if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
logger.info(f"args: {args}")
|
||||
|
||||
dbs = get_database_list()
|
||||
logger.info(args)
|
||||
demo = build_webdemo()
|
||||
demo.queue(
|
||||
|
Loading…
Reference in New Issue
Block a user