mirror of https://github.com/csunny/DB-GPT.git
impl gradio web
This commit is contained in:
parent d73cab9c4b
commit 90dd34e25a
@@ -1,13 +1,17 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-

+import argparse
 import os
 import uuid
+import json
 import time
 import gradio as gr
 import datetime
+import requests
+from urllib.parse import urljoin

-from pilot.configs.model_config import LOGDIR
+from pilot.configs.model_config import LOGDIR, vicuna_model_server, LLM_MODEL

 from pilot.conversation import (
     get_default_conv_template,
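The widened import assumes pilot.configs.model_config now exposes vicuna_model_server and LLM_MODEL alongside LOGDIR. A minimal sketch of what that config module might contain; the concrete values below are illustrative assumptions, not the repository's actual settings:

# pilot/configs/model_config.py (hypothetical excerpt)
import os

LOGDIR = os.getenv("DBGPT_LOGDIR", "logs")       # where conversation logs are written
LLM_MODEL = "vicuna-13b"                         # default model name sent to the worker
vicuna_model_server = "http://127.0.0.1:8000"    # base URL of the model worker API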
@@ -39,10 +43,41 @@ priority = {
     "vicuna-13b": "aaa"
 }

-def set_global_vars(enable_moderation_, models_):
+def set_global_vars(enable_moderation_):
     global enable_moderation, models
     enable_moderation = enable_moderation_
-    models = models_
+
+
+def load_demo_single(url_params):
+    dropdown_update = gr.Dropdown.update(visible=True)
+    if "model" in url_params:
+        model = url_params["model"]
+        if model in models:
+            dropdown_update = gr.Dropdown.update(value=model, visible=True)
+
+    state = None
+    return (
+        state,
+        dropdown_update,
+        gr.Chatbot.update(visible=True),
+        gr.Textbox.update(visible=True),
+        gr.Button.update(visible=True),
+        gr.Row.update(visible=True),
+        gr.Accordion.update(visible=True),
+    )
+
+
+get_window_url_params = """
+function() {
+    const params = new URLSearchParams(window.location.search);
+    url_params = Object.fromEntries(params);
+    console.log(url_params);
+    return url_params;
+    }
+"""
+
+
+def load_demo(url_params, request: gr.Request):
+    logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
+    return load_demo_single(url_params)
+

 def get_conv_log_filename():
     t = datetime.datetime.now()
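The load_demo/get_window_url_params pair follows the Gradio 3.x pattern of running a JavaScript snippet in the browser on page load and feeding its return value into a Python callback. A self-contained sketch of that pattern, independent of this file's other helpers:

import gradio as gr

url_params_js = """
function() {
    const params = new URLSearchParams(window.location.search);
    return Object.fromEntries(params);
}
"""

def on_load(url_params, request: gr.Request):
    # url_params arrives as the dict returned by the JS snippet above
    return f"params: {url_params}, client: {request.client.host}"

with gr.Blocks() as demo:
    url_params = gr.JSON(visible=False)
    out = gr.Textbox()
    demo.load(on_load, [url_params], [out], _js=url_params_js)

demo.launch()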
@@ -94,11 +129,11 @@ def post_process_code(code):
     code = sep.join(blocks)
     return code

-def http_bot(state, model_selector, temperature, max_new_tokens, request: gr.Request):
+def http_bot(state, temperature, max_new_tokens, request: gr.Request):
     logger.info(f"http_bot. ip: {request.client.host}")
     start_tstamp = time.time()

-    model_name = model_selector
+    model_name = LLM_MODEL
     temperature = float(temperature)
     max_new_tokens = int(max_new_tokens)

@@ -106,7 +141,7 @@ def http_bot(state, model_selector, temperature, max_new_tokens, request: gr.Req
         yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
         return

-    if len(state.message) == state.offset + 2:
+    if len(state.messages) == state.offset + 2:
         new_state = get_default_conv_template(model_name).copy()
         new_state.conv_id = uuid.uuid4().hex
         new_state.append_message(new_state.roles[0], state.messages[-2][1])
@@ -114,4 +149,238 @@ def http_bot(state, model_selector, temperature, max_new_tokens, request: gr.Req
         state = new_state


-    # TODO
+    prompt = state.get_prompt()
+    skip_echo_len = compute_skip_echo_len(prompt)
+
+    payload = {
+        "model": model_name,
+        "prompt": prompt,
+        "temperature": temperature,
+        "max_new_tokens": max_new_tokens,
+        "stop": state.sep if state.sep_style == SeparatorStyle.SINGLE else None,
+    }
+
+    logger.info(f"Request: \n {payload}")
+    state.messages[-1][-1] = "▌"
+    yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
+
+    try:
+        response = requests.post(
+            url=urljoin(vicuna_model_server, "generate_stream"),
+            headers=headers,
+            json=payload,
+            stream=True,
+            timeout=60,
+        )
+        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
+            if chunk:
+                data = json.loads(chunk.decode())
+                if data["error_code"] == 0:
+                    output = data["text"][skip_echo_len:].strip()
+                    output = post_process_code(output)
+                    state.messages[-1][-1] = output + "▌"
+                    yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
+                else:
+                    output = data["text"] + f" (error_code: {data['error_code']})"
+                    state.messages[-1][-1] = output
+                    yield (state, state.to_gradio_chatbot()) + (
+                        disable_btn,
+                        disable_btn,
+                        disable_btn,
+                        enable_btn,
+                        enable_btn,
+                    )
+
+                    return
+                time.sleep(0.02)
+    except requests.exceptions.RequestException as e:
+        state.messages[-1][-1] = server_error_msg + f" (error_code: 4)"
+        yield (state, state.to_gradio_chatbot()) + (
+            disable_btn,
+            disable_btn,
+            disable_btn,
+            enable_btn,
+            enable_btn,
+        )
+        return
+
+    state.messages[-1][-1] = state.messages[-1][-1][:-1]
+    yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
+
+    finish_tstamp = time.time()
+    logger.info(f"{output}")
+
+    with open(get_conv_log_filename(), "a") as flog:
+        data = {
+            "tstamp": round(finish_tstamp, 4),
+            "type": "chat",
+            "model": model_name,
+            "gen_params": {
+                "temperature": temperature,
+                "max_new_tokens": max_new_tokens,
+            },
+            "start": round(start_tstamp, 4),
+            "finish": round(finish_tstamp, 4),
+            "state": state.dict(),
+            "ip": request.client.host,
+        }
+        flog.write(json.dumps(data) + "\n")
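For context, the worker protocol http_bot consumes above (NUL-delimited JSON chunks, each carrying the full text generated so far plus an error_code) can be exercised without Gradio. A minimal standalone client sketch, assuming the same generate_stream endpoint and payload shape; the server URL and model name are assumptions:

import json
import requests
from urllib.parse import urljoin

def stream_generate(prompt, server="http://127.0.0.1:8000"):
    payload = {
        "model": "vicuna-13b",
        "prompt": prompt,
        "temperature": 0.7,
        "max_new_tokens": 512,
        "stop": None,
    }
    response = requests.post(
        url=urljoin(server, "generate_stream"),
        json=payload,
        stream=True,
        timeout=60,
    )
    for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
        if chunk:
            data = json.loads(chunk.decode())
            if data["error_code"] != 0:
                raise RuntimeError(f"worker error {data['error_code']}: {data['text']}")
            yield data["text"]  # cumulative output so far, not a delta

Each yielded value supersedes the previous one, mirroring how http_bot overwrites state.messages[-1][-1] on every chunk.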
+
+
+block_css = (
+    code_highlight_css
+    + """
+pre {
+    white-space: pre-wrap;       /* Since CSS 2.1 */
+    white-space: -moz-pre-wrap;  /* Mozilla, since 1999 */
+    white-space: -pre-wrap;      /* Opera 4-6 */
+    white-space: -o-pre-wrap;    /* Opera 7 */
+    word-wrap: break-word;       /* Internet Explorer 5.5+ */
+}
+#notice_markdown th {
+    display: none;
+}
+"""
+)
+
+
+def build_single_model_ui():
+
+    notice_markdown = """
+# DB-GPT
+
+[DB-GPT](https://github.com/csunny/DB-GPT) is an experimental open-source application based on [FastChat](https://github.com/lm-sys/FastChat), with vicuna as the base model. On top of that, it combines langchain and llama-index,
+using In-Context Learning over an existing knowledge base to augment the model with database-specific knowledge. Overall, it aims to be a sophisticated and innovative AI tool for databases. If you have specific questions about using or deploying DB-GPT in your work, please ask and I will do my best to help.
+"""
+    learn_more_markdown = """
+### License
+The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA.
+"""
+
+    state = gr.State()
+    notice = gr.Markdown(notice_markdown, elem_id="notice_markdown")
+
+    chatbot = grChatbot(elem_id="chatbot", visible=False).style(height=550)
+    with gr.Row():
+        with gr.Column(scale=20):
+            textbox = gr.Textbox(
+                show_label=False,
+                placeholder="Enter text and press ENTER",
+                visible=False,
+            ).style(container=False)
+
+        with gr.Column(scale=2, min_width=50):
+            send_btn = gr.Button(value="Send", visible=False)
+
+    with gr.Row(visible=False) as button_row:
+        regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
+        clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
+
+    with gr.Accordion("Parameters", open=False, visible=False) as parameter_row:
+        temperature = gr.Slider(
+            minimum=0.0,
+            maximum=1.0,
+            value=0.7,
+            step=0.1,
+            interactive=True,
+            label="Temperature",
+        )
+
+        max_output_tokens = gr.Slider(
+            minimum=0,
+            maximum=1024,
+            value=512,
+            step=64,
+            interactive=True,
+            label="Max output tokens",
+        )
+
+    gr.Markdown(learn_more_markdown)
+
+    btn_list = [regenerate_btn, clear_btn]
+    regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
+        http_bot,
+        [state, temperature, max_output_tokens],
+        [state, chatbot] + btn_list,
+    )
+    clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)
+
+    textbox.submit(
+        add_text, [state, textbox], [state, chatbot, textbox] + btn_list
+    ).then(
+        http_bot,
+        [state, temperature, max_output_tokens],
+        [state, chatbot] + btn_list,
+    )
+
+    send_btn.click(
+        add_text, [state, textbox], [state, chatbot, textbox] + btn_list
+    ).then(
+        http_bot,
+        [state, temperature, max_output_tokens],
+        [state, chatbot] + btn_list
+    )
+
+    return state, chatbot, textbox, send_btn, button_row, parameter_row
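The wiring above leans on Gradio's event chaining: .click()/.submit() first runs a fast synchronous handler (add_text), then .then() streams the generator's yields (http_bot) into the chatbot. A toy sketch of the same chaining with stand-in handlers:

import time
import gradio as gr

def add_text(history, text):
    # fast, synchronous: record the user turn and clear the textbox
    return history + [[text, None]], ""

def bot(history):
    # slow, generator-based: stream the reply a character at a time
    reply = "echo: " + history[-1][0]
    history[-1][1] = ""
    for ch in reply:
        history[-1][1] += ch
        time.sleep(0.02)
        yield history

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    textbox = gr.Textbox()
    textbox.submit(add_text, [chatbot, textbox], [chatbot, textbox]).then(
        bot, chatbot, chatbot
    )

demo.queue().launch()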
+
+
+def build_webdemo():
+    with gr.Blocks(
+        title="Database AI Assistant",
+        theme=gr.themes.Base(),
+        # theme=gr.themes.Monochrome(),
+        css=block_css,
+    ) as demo:
+        url_params = gr.JSON(visible=False)
+        (
+            state,
+            chatbot,
+            textbox,
+            send_btn,
+            button_row,
+            parameter_row,
+        ) = build_single_model_ui()
+
+        if args.model_list_mode == "once":
+            demo.load(
+                load_demo,
+                [url_params],
+                [
+                    state,
+                    chatbot,
+                    textbox,
+                    send_btn,
+                    button_row,
+                    parameter_row,
+                ],
+                _js=get_window_url_params,
+            )
+        else:
+            raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
+        return demo
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--host", type=str, default="0.0.0.0")
+    parser.add_argument("--port", type=int)
+    parser.add_argument("--concurrency-count", type=int, default=10)
+    parser.add_argument(
+        "--model-list-mode", type=str, default="once", choices=["once", "reload"]
+    )
+    parser.add_argument("--share", default=False, action="store_true")
+    parser.add_argument(
+        "--moderate", action="store_true", help="Enable content moderation"
+    )
+    args = parser.parse_args()
+    logger.info(f"args: {args}")
+
+    set_global_vars(args.moderate)
+
+    logger.info(args)
+    demo = build_webdemo()
+    demo.queue(
+        concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False
+    ).launch(
+        server_name=args.host, server_port=args.port, share=args.share, max_threads=200,
+    )
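With the argparse flags defined above, the server would typically be started directly; the module path below is an assumption based on the pilot package layout:

python pilot/server/webserver.py --port 7860 --model-list-mode once

Since this commit drops the model selector from the UI, every chat turn is routed to the worker configured by vicuna_model_server and LLM_MODEL.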