mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-25 04:53:36 +00:00
fix
This commit is contained in:
parent
e132980127
commit
ca50c9fe47
19
examples/gradio_test.py
Normal file
19
examples/gradio_test.py
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding:utf-8 -*-
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
def change_tab(selected: int = 1):
    """Return a Tabs update that switches the visible tab.

    Args:
        selected: id of the tab to activate. Defaults to 1, the
            "Inference" tab defined in the demo below, preserving the
            original hard-coded behavior for existing callers.

    Returns:
        A gradio update object targeting a ``gr.Tabs`` container.
    """
    # NOTE(review): gr.Tabs.update() is the gradio 3.x API; gradio 4.x
    # replaced it with gr.Tabs(selected=...) — confirm installed version.
    return gr.Tabs.update(selected=selected)
|
||||||
|
|
||||||
|
# Minimal two-tab demo: pressing the button programmatically jumps
# from the "Train" tab to the "Inference" tab via change_tab().
with gr.Blocks() as demo:
    with gr.Tabs() as tabs:
        with gr.TabItem("Train", id=0):
            train_box = gr.Textbox()
        with gr.TabItem("Inference", id=1):
            inference_image = gr.Image()

    switch_btn = gr.Button()
    # The click output is the Tabs container itself, so the update
    # returned by change_tab() can change which tab is selected.
    switch_btn.click(change_tab, None, tabs)

demo.launch()
|
@ -28,8 +28,7 @@ VECTOR_SEARCH_TOP_K = 3
|
|||||||
LLM_MODEL = "vicuna-13b"
|
LLM_MODEL = "vicuna-13b"
|
||||||
LIMIT_MODEL_CONCURRENCY = 5
|
LIMIT_MODEL_CONCURRENCY = 5
|
||||||
MAX_POSITION_EMBEDDINGS = 4096
|
MAX_POSITION_EMBEDDINGS = 4096
|
||||||
VICUNA_MODEL_SERVER = "http://192.168.31.114:8000"
|
VICUNA_MODEL_SERVER = "http://47.97.125.199:8000"
|
||||||
|
|
||||||
|
|
||||||
# Load model config
|
# Load model config
|
||||||
ISLOAD_8BIT = True
|
ISLOAD_8BIT = True
|
||||||
|
@ -45,6 +45,7 @@ enable_moderation = False
|
|||||||
models = []
|
models = []
|
||||||
dbs = []
|
dbs = []
|
||||||
vs_list = ["新建知识库"] + get_vector_storelist()
|
vs_list = ["新建知识库"] + get_vector_storelist()
|
||||||
|
autogpt = False
|
||||||
|
|
||||||
priority = {
|
priority = {
|
||||||
"vicuna-13b": "aaa"
|
"vicuna-13b": "aaa"
|
||||||
@ -58,8 +59,6 @@ def get_simlar(q):
|
|||||||
contents = [dc.page_content for dc, _ in docs]
|
contents = [dc.page_content for dc, _ in docs]
|
||||||
return "\n".join(contents)
|
return "\n".join(contents)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def gen_sqlgen_conversation(dbname):
|
def gen_sqlgen_conversation(dbname):
|
||||||
mo = MySQLOperator(
|
mo = MySQLOperator(
|
||||||
**DB_SETTINGS
|
**DB_SETTINGS
|
||||||
@ -118,6 +117,8 @@ def regenerate(state, request: gr.Request):
|
|||||||
return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
|
return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
|
||||||
|
|
||||||
def clear_history(request: gr.Request):
|
def clear_history(request: gr.Request):
|
||||||
|
|
||||||
|
|
||||||
logger.info(f"clear_history. ip: {request.client.host}")
|
logger.info(f"clear_history. ip: {request.client.host}")
|
||||||
state = None
|
state = None
|
||||||
return (state, [], "") + (disable_btn,) * 5
|
return (state, [], "") + (disable_btn,) * 5
|
||||||
@ -135,7 +136,7 @@ def add_text(state, text, request: gr.Request):
|
|||||||
return (state, state.to_gradio_chatbot(), moderation_msg) + (
|
return (state, state.to_gradio_chatbot(), moderation_msg) + (
|
||||||
no_change_btn,) * 5
|
no_change_btn,) * 5
|
||||||
|
|
||||||
text = text[:1536] # Hard cut-off
|
text = text[:4000] # Hard cut-off
|
||||||
state.append_message(state.roles[0], text)
|
state.append_message(state.roles[0], text)
|
||||||
state.append_message(state.roles[1], None)
|
state.append_message(state.roles[1], None)
|
||||||
state.skip_next = False
|
state.skip_next = False
|
||||||
@ -152,6 +153,8 @@ def post_process_code(code):
|
|||||||
return code
|
return code
|
||||||
|
|
||||||
def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr.Request):
|
def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr.Request):
|
||||||
|
|
||||||
|
print("是否是AUTO-GPT模式.", autogpt)
|
||||||
start_tstamp = time.time()
|
start_tstamp = time.time()
|
||||||
model_name = LLM_MODEL
|
model_name = LLM_MODEL
|
||||||
|
|
||||||
@ -163,6 +166,7 @@ def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr.
|
|||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
|
# TODO when tab mode is AUTO_GPT, Prompt need to rebuild.
|
||||||
if len(state.messages) == state.offset + 2:
|
if len(state.messages) == state.offset + 2:
|
||||||
# 第一轮对话需要加入提示Prompt
|
# 第一轮对话需要加入提示Prompt
|
||||||
|
|
||||||
@ -264,15 +268,14 @@ pre {
|
|||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
def change_tab(tab):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def change_mode(mode):
|
def change_mode(mode):
|
||||||
if mode in ["默认知识库对话", "LLM原生对话"]:
|
if mode in ["默认知识库对话", "LLM原生对话"]:
|
||||||
return gr.update(visible=False)
|
return gr.update(visible=False)
|
||||||
else:
|
else:
|
||||||
return gr.update(visible=True)
|
return gr.update(visible=True)
|
||||||
|
|
||||||
|
def change_tab():
|
||||||
|
autogpt = True
|
||||||
|
|
||||||
def build_single_model_ui():
|
def build_single_model_ui():
|
||||||
|
|
||||||
@ -309,7 +312,8 @@ def build_single_model_ui():
|
|||||||
)
|
)
|
||||||
tabs= gr.Tabs()
|
tabs= gr.Tabs()
|
||||||
with tabs:
|
with tabs:
|
||||||
with gr.TabItem("SQL生成与诊断", elem_id="SQL"):
|
tab_sql = gr.TabItem("SQL生成与诊断", elem_id="SQL")
|
||||||
|
with tab_sql:
|
||||||
# TODO A selector to choose database
|
# TODO A selector to choose database
|
||||||
with gr.Row(elem_id="db_selector"):
|
with gr.Row(elem_id="db_selector"):
|
||||||
db_selector = gr.Dropdown(
|
db_selector = gr.Dropdown(
|
||||||
@ -318,9 +322,12 @@ def build_single_model_ui():
|
|||||||
value=dbs[0] if len(models) > 0 else "",
|
value=dbs[0] if len(models) > 0 else "",
|
||||||
interactive=True,
|
interactive=True,
|
||||||
show_label=True).style(container=False)
|
show_label=True).style(container=False)
|
||||||
|
tab_auto = gr.TabItem("AUTO-GPT", elem_id="auto")
|
||||||
|
with tab_auto:
|
||||||
|
gr.Markdown("自动执行模式下, DB-GPT可以具备执行SQL、从网络读取知识自动化存储学习的能力")
|
||||||
|
|
||||||
with gr.TabItem("知识问答", elem_id="QA"):
|
tab_qa = gr.TabItem("知识问答", elem_id="QA")
|
||||||
|
with tab_qa:
|
||||||
mode = gr.Radio(["LLM原生对话", "默认知识库对话", "新增知识库对话"], show_label=False, value="LLM原生对话")
|
mode = gr.Radio(["LLM原生对话", "默认知识库对话", "新增知识库对话"], show_label=False, value="LLM原生对话")
|
||||||
vs_setting = gr.Accordion("配置知识库", open=False)
|
vs_setting = gr.Accordion("配置知识库", open=False)
|
||||||
mode.change(fn=change_mode, inputs=mode, outputs=vs_setting)
|
mode.change(fn=change_mode, inputs=mode, outputs=vs_setting)
|
||||||
@ -360,9 +367,7 @@ def build_single_model_ui():
|
|||||||
regenerate_btn = gr.Button(value="重新生成", interactive=False)
|
regenerate_btn = gr.Button(value="重新生成", interactive=False)
|
||||||
clear_btn = gr.Button(value="清理", interactive=False)
|
clear_btn = gr.Button(value="清理", interactive=False)
|
||||||
|
|
||||||
|
|
||||||
gr.Markdown(learn_more_markdown)
|
gr.Markdown(learn_more_markdown)
|
||||||
|
|
||||||
btn_list = [regenerate_btn, clear_btn]
|
btn_list = [regenerate_btn, clear_btn]
|
||||||
regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
|
regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
|
||||||
http_bot,
|
http_bot,
|
||||||
|
Loading…
Reference in New Issue
Block a user