mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-28 22:37:31 +00:00
connect database
This commit is contained in:
parent
a71c8b6d56
commit
88a5c57646
@ -25,3 +25,11 @@ vicuna_model_server = "http://192.168.31.114:8000"
|
|||||||
# Load model config
|
# Load model config
|
||||||
isload_8bit = True
|
isload_8bit = True
|
||||||
isdebug = False
|
isdebug = False
|
||||||
|
|
||||||
|
|
||||||
|
DB_SETTINGS = {
|
||||||
|
"user": "root",
|
||||||
|
"password": "********",
|
||||||
|
"host": "localhost",
|
||||||
|
"port": 3306
|
||||||
|
}
|
@ -5,6 +5,8 @@ import pymysql
|
|||||||
|
|
||||||
class MySQLOperator:
|
class MySQLOperator:
|
||||||
"""Connect MySQL Database fetch MetaData For LLM Prompt """
|
"""Connect MySQL Database fetch MetaData For LLM Prompt """
|
||||||
|
|
||||||
|
default_db = ["information_schema", "performance_schema", "sys", "mysql"]
|
||||||
def __init__(self, user, password, host="localhost", port=3306) -> None:
|
def __init__(self, user, password, host="localhost", port=3306) -> None:
|
||||||
|
|
||||||
self.conn = pymysql.connect(
|
self.conn = pymysql.connect(
|
||||||
@ -25,12 +27,16 @@ class MySQLOperator:
|
|||||||
results = cursor.fetchall()
|
results = cursor.fetchall()
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
def get_db_list(self):
|
||||||
|
with self.conn.cursor() as cursor:
|
||||||
|
_sql = """
|
||||||
|
show databases;
|
||||||
|
"""
|
||||||
|
cursor.execute(_sql)
|
||||||
|
results = cursor.fetchall()
|
||||||
|
|
||||||
|
dbs = [d["Database"] for d in results if d["Database"] not in self.default_db]
|
||||||
|
return dbs
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
mo = MySQLOperator(
|
|
||||||
"root",
|
|
||||||
"aa123456",
|
|
||||||
)
|
|
||||||
|
|
||||||
schema = mo.get_schema("dbgpt")
|
|
||||||
print(schema)
|
|
@ -4,7 +4,7 @@
|
|||||||
import dataclasses
|
import dataclasses
|
||||||
from enum import auto, Enum
|
from enum import auto, Enum
|
||||||
from typing import List, Any
|
from typing import List, Any
|
||||||
|
from pilot.configs.model_config import DB_SETTINGS
|
||||||
|
|
||||||
class SeparatorStyle(Enum):
|
class SeparatorStyle(Enum):
|
||||||
|
|
||||||
@ -88,6 +88,19 @@ class Conversation:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def gen_sqlgen_conversation(dbname):
|
||||||
|
from pilot.connections.mysql_conn import MySQLOperator
|
||||||
|
mo = MySQLOperator(
|
||||||
|
**DB_SETTINGS
|
||||||
|
)
|
||||||
|
|
||||||
|
message = ""
|
||||||
|
|
||||||
|
schemas = mo.get_schema(dbname)
|
||||||
|
for s in schemas:
|
||||||
|
message += s["schema_info"] + ";"
|
||||||
|
return f"数据库{dbname}的Schema信息如下: {message}\n"
|
||||||
|
|
||||||
conv_one_shot = Conversation(
|
conv_one_shot = Conversation(
|
||||||
system="A chat between a curious human and an artificial intelligence assistant, who very familiar with database related knowledge. "
|
system="A chat between a curious human and an artificial intelligence assistant, who very familiar with database related knowledge. "
|
||||||
"The assistant gives helpful, detailed, professional and polite answers to the human's questions. ",
|
"The assistant gives helpful, detailed, professional and polite answers to the human's questions. ",
|
||||||
@ -121,7 +134,7 @@ conv_one_shot = Conversation(
|
|||||||
sep_style=SeparatorStyle.SINGLE,
|
sep_style=SeparatorStyle.SINGLE,
|
||||||
sep="###"
|
sep="###"
|
||||||
)
|
)
|
||||||
|
|
||||||
conv_vicuna_v1 = Conversation(
|
conv_vicuna_v1 = Conversation(
|
||||||
system = "A chat between a curious user and an artificial intelligence assistant. who very familiar with database related knowledge. "
|
system = "A chat between a curious user and an artificial intelligence assistant. who very familiar with database related knowledge. "
|
||||||
"The assistant gives helpful, detailed, professional and polite answers to the user's questions. ",
|
"The assistant gives helpful, detailed, professional and polite answers to the user's questions. ",
|
||||||
@ -137,5 +150,10 @@ default_conversation = conv_one_shot
|
|||||||
|
|
||||||
conv_templates = {
|
conv_templates = {
|
||||||
"conv_one_shot": conv_one_shot,
|
"conv_one_shot": conv_one_shot,
|
||||||
"vicuna_v1": conv_vicuna_v1
|
"vicuna_v1": conv_vicuna_v1,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
message = gen_sqlgen_conversation("dbgpt")
|
||||||
|
print(message)
|
@ -10,6 +10,9 @@ import gradio as gr
|
|||||||
import datetime
|
import datetime
|
||||||
import requests
|
import requests
|
||||||
from urllib.parse import urljoin
|
from urllib.parse import urljoin
|
||||||
|
from pilot.configs.model_config import DB_SETTINGS
|
||||||
|
from pilot.connections.mysql_conn import MySQLOperator
|
||||||
|
|
||||||
|
|
||||||
from pilot.configs.model_config import LOGDIR, vicuna_model_server, LLM_MODEL
|
from pilot.configs.model_config import LOGDIR, vicuna_model_server, LLM_MODEL
|
||||||
|
|
||||||
@ -29,7 +32,7 @@ from fastchat.utils import (
|
|||||||
from fastchat.serve.gradio_patch import Chatbot as grChatbot
|
from fastchat.serve.gradio_patch import Chatbot as grChatbot
|
||||||
from fastchat.serve.gradio_css import code_highlight_css
|
from fastchat.serve.gradio_css import code_highlight_css
|
||||||
|
|
||||||
logger = build_logger("webserver", "webserver.log")
|
logger = build_logger("webserver", LOGDIR + "webserver.log")
|
||||||
headers = {"User-Agent": "dbgpt Client"}
|
headers = {"User-Agent": "dbgpt Client"}
|
||||||
|
|
||||||
no_change_btn = gr.Button.update()
|
no_change_btn = gr.Button.update()
|
||||||
@ -38,11 +41,28 @@ disable_btn = gr.Button.update(interactive=True)
|
|||||||
|
|
||||||
enable_moderation = False
|
enable_moderation = False
|
||||||
models = []
|
models = []
|
||||||
|
dbs = []
|
||||||
|
|
||||||
priority = {
|
priority = {
|
||||||
"vicuna-13b": "aaa"
|
"vicuna-13b": "aaa"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def gen_sqlgen_conversation(dbname):
|
||||||
|
mo = MySQLOperator(
|
||||||
|
**DB_SETTINGS
|
||||||
|
)
|
||||||
|
|
||||||
|
message = ""
|
||||||
|
|
||||||
|
schemas = mo.get_schema(dbname)
|
||||||
|
for s in schemas:
|
||||||
|
message += s["schema_info"] + ";"
|
||||||
|
return f"数据库{dbname}的Schema信息如下: {message}\n"
|
||||||
|
|
||||||
|
def get_database_list():
|
||||||
|
mo = MySQLOperator(**DB_SETTINGS)
|
||||||
|
return mo.get_db_list()
|
||||||
|
|
||||||
get_window_url_params = """
|
get_window_url_params = """
|
||||||
function() {
|
function() {
|
||||||
const params = new URLSearchParams(window.location.search);
|
const params = new URLSearchParams(window.location.search);
|
||||||
@ -58,12 +78,10 @@ function() {
|
|||||||
def load_demo(url_params, request: gr.Request):
|
def load_demo(url_params, request: gr.Request):
|
||||||
logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
|
logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
|
||||||
|
|
||||||
|
dbs = get_database_list()
|
||||||
dropdown_update = gr.Dropdown.update(visible=True)
|
dropdown_update = gr.Dropdown.update(visible=True)
|
||||||
if "model" in url_params:
|
if dbs:
|
||||||
model = url_params["model"]
|
gr.Dropdown.update(choices=dbs)
|
||||||
if model in models:
|
|
||||||
dropdown_update = gr.Dropdown.update(
|
|
||||||
value=model, visible=True)
|
|
||||||
|
|
||||||
state = default_conversation.copy()
|
state = default_conversation.copy()
|
||||||
return (state,
|
return (state,
|
||||||
@ -120,10 +138,11 @@ def post_process_code(code):
|
|||||||
code = sep.join(blocks)
|
code = sep.join(blocks)
|
||||||
return code
|
return code
|
||||||
|
|
||||||
def http_bot(state, temperature, max_new_tokens, request: gr.Request):
|
def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Request):
|
||||||
start_tstamp = time.time()
|
start_tstamp = time.time()
|
||||||
model_name = LLM_MODEL
|
model_name = LLM_MODEL
|
||||||
|
|
||||||
|
dbname = db_selector
|
||||||
# TODO 这里的请求需要拼接现有知识库, 使得其根据现有知识库作答, 所以prompt需要继续优化
|
# TODO 这里的请求需要拼接现有知识库, 使得其根据现有知识库作答, 所以prompt需要继续优化
|
||||||
if state.skip_next:
|
if state.skip_next:
|
||||||
# This generate call is skipped due to invalid inputs
|
# This generate call is skipped due to invalid inputs
|
||||||
@ -131,16 +150,20 @@ def http_bot(state, temperature, max_new_tokens, request: gr.Request):
|
|||||||
return
|
return
|
||||||
|
|
||||||
if len(state.messages) == state.offset + 2:
|
if len(state.messages) == state.offset + 2:
|
||||||
# First round of conversation
|
# 第一轮对话需要加入提示Prompt
|
||||||
|
|
||||||
template_name = "conv_one_shot"
|
template_name = "conv_one_shot"
|
||||||
new_state = conv_templates[template_name].copy()
|
new_state = conv_templates[template_name].copy()
|
||||||
new_state.conv_id = uuid.uuid4().hex
|
new_state.conv_id = uuid.uuid4().hex
|
||||||
new_state.append_message(new_state.roles[0], state.messages[-2][1])
|
|
||||||
|
# prompt 中添加上下文提示
|
||||||
|
new_state.append_message(new_state.roles[0], gen_sqlgen_conversation(dbname) + state.messages[-2][1])
|
||||||
new_state.append_message(new_state.roles[1], None)
|
new_state.append_message(new_state.roles[1], None)
|
||||||
state = new_state
|
state = new_state
|
||||||
|
|
||||||
|
|
||||||
prompt = state.get_prompt()
|
prompt = state.get_prompt()
|
||||||
|
|
||||||
skip_echo_len = len(prompt.replace("</s>", " ")) + 1
|
skip_echo_len = len(prompt.replace("</s>", " ")) + 1
|
||||||
|
|
||||||
# Make requests
|
# Make requests
|
||||||
@ -250,11 +273,16 @@ def build_single_model_ui():
|
|||||||
|
|
||||||
with gr.Tabs():
|
with gr.Tabs():
|
||||||
with gr.TabItem("知识问答", elem_id="QA"):
|
with gr.TabItem("知识问答", elem_id="QA"):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
with gr.TabItem("SQL生成与诊断", elem_id="SQL"):
|
with gr.TabItem("SQL生成与诊断", elem_id="SQL"):
|
||||||
# TODO A selector to choose database
|
# TODO A selector to choose database
|
||||||
pass
|
with gr.Row(elem_id="db_selector"):
|
||||||
|
db_selector = gr.Dropdown(
|
||||||
|
label="请选择数据库",
|
||||||
|
choices=dbs,
|
||||||
|
value=dbs[0] if len(models) > 0 else "",
|
||||||
|
interactive=True,
|
||||||
|
show_label=True).style(container=False)
|
||||||
|
|
||||||
with gr.Blocks():
|
with gr.Blocks():
|
||||||
chatbot = grChatbot(elem_id="chatbot", visible=False).style(height=550)
|
chatbot = grChatbot(elem_id="chatbot", visible=False).style(height=550)
|
||||||
@ -277,7 +305,7 @@ def build_single_model_ui():
|
|||||||
btn_list = [regenerate_btn, clear_btn]
|
btn_list = [regenerate_btn, clear_btn]
|
||||||
regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
|
regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
|
||||||
http_bot,
|
http_bot,
|
||||||
[state, temperature, max_output_tokens],
|
[state, db_selector, temperature, max_output_tokens],
|
||||||
[state, chatbot] + btn_list,
|
[state, chatbot] + btn_list,
|
||||||
)
|
)
|
||||||
clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)
|
clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)
|
||||||
@ -286,7 +314,7 @@ def build_single_model_ui():
|
|||||||
add_text, [state, textbox], [state, chatbot, textbox] + btn_list
|
add_text, [state, textbox], [state, chatbot, textbox] + btn_list
|
||||||
).then(
|
).then(
|
||||||
http_bot,
|
http_bot,
|
||||||
[state, temperature, max_output_tokens],
|
[state, db_selector, temperature, max_output_tokens],
|
||||||
[state, chatbot] + btn_list,
|
[state, chatbot] + btn_list,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -294,7 +322,7 @@ def build_single_model_ui():
|
|||||||
add_text, [state, textbox], [state, chatbot, textbox] + btn_list
|
add_text, [state, textbox], [state, chatbot, textbox] + btn_list
|
||||||
).then(
|
).then(
|
||||||
http_bot,
|
http_bot,
|
||||||
[state, temperature, max_output_tokens],
|
[state, db_selector, temperature, max_output_tokens],
|
||||||
[state, chatbot] + btn_list
|
[state, chatbot] + btn_list
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -351,6 +379,7 @@ if __name__ == "__main__":
|
|||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
logger.info(f"args: {args}")
|
logger.info(f"args: {args}")
|
||||||
|
|
||||||
|
dbs = get_database_list()
|
||||||
logger.info(args)
|
logger.info(args)
|
||||||
demo = build_webdemo()
|
demo = build_webdemo()
|
||||||
demo.queue(
|
demo.queue(
|
||||||
|
Loading…
Reference in New Issue
Block a user