Mirror of https://github.com/csunny/DB-GPT.git, synced 2025-07-31 15:47:05 +00:00

Commit 90dd34e25a: impl gradio web
Parent: d73cab9c4b
@@ -1,13 +1,17 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse
import os
import uuid
import json
import time
import gradio as gr
import datetime
+import requests
+from urllib.parse import urljoin

-from pilot.configs.model_config import LOGDIR
+from pilot.configs.model_config import LOGDIR, vicuna_model_server, LLM_MODEL

from pilot.conversation import (
    get_default_conv_template,
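The new imports pull server settings from pilot.configs.model_config, a module that is not part of this diff. For orientation, a minimal sketch of the values the code below relies on; only the names come from the import, the concrete values are assumptions:

# Hypothetical sketch of the pilot/configs/model_config.py values used here.
# Names match the import above; the values are illustrative assumptions.
import os

LOGDIR = os.getenv("DBGPT_LOGDIR", "logs")        # directory for conversation logs
LLM_MODEL = "vicuna-13b"                          # model name placed in request payloads
vicuna_model_server = "http://127.0.0.1:8000"     # base URL of the model worker HTTP API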
@@ -39,10 +43,41 @@ priority = {
    "vicuna-13b": "aaa"
}

-def set_global_vars(enable_moderation_, models_):
+def set_global_vars(enable_moderation_):
    global enable_moderation, models
    enable_moderation = enable_moderation_
-    models = models_

def load_demo_single(url_params):
    dropdown_update = gr.Dropdown.update(visible=True)
    if "model" in url_params:
        model = url_params["model"]
        if model in models:
            dropdown_update = gr.Dropdown.update(value=model, visible=True)

    state = None
    return (
        state,
        dropdown_update,
        gr.Chatbot.update(visible=True),
        gr.Textbox.update(visible=True),
        gr.Button.update(visible=True),
        gr.Row.update(visible=True),
        gr.Accordion.update(visible=True),
    )


get_window_url_params = """
function() {
    const params = new URLSearchParams(window.location.search);
    url_params = Object.fromEntries(params);
    console.log(url_params);
    return url_params;
}
"""


def load_demo(url_params, request: gr.Request):
    logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
    return load_demo_single(url_params)


def get_conv_log_filename():
    t = datetime.datetime.now()
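Taken together, get_window_url_params runs in the browser and hands the page's query string to load_demo as a dict, so a link like ?model=vicuna-13b preselects that model. A small illustration of the resulting call, assuming the module-level models list has been populated elsewhere:

# Illustration only: simulate what load_demo_single receives for /?model=vicuna-13b,
# assuming models was populated elsewhere in the module.
models = ["vicuna-13b"]

url_params = {"model": "vicuna-13b"}  # what the JS snippet returns
state, dropdown_update, *visibility_updates = load_demo_single(url_params)
# dropdown_update selects "vicuna-13b"; the remaining updates make the
# chatbot, textbox, send button, button row, and accordion visible.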
@@ -94,11 +129,11 @@ def post_process_code(code):
    code = sep.join(blocks)
    return code

-def http_bot(state, model_selector, temperature, max_new_tokens, request: gr.Request):
+def http_bot(state, temperature, max_new_tokens, request: gr.Request):
    logger.info(f"http_bot. ip: {request.client.host}")
    start_tstamp = time.time()

-    model_name = model_selector
+    model_name = LLM_MODEL
    temperature = float(temperature)
    max_new_tokens = int(max_new_tokens)
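The hunks below toggle the action buttons in bulk with tuples such as (no_change_btn,) * 5. Those helpers are not defined in this diff; plausible definitions, mirroring the FastChat module this file adapts (an assumption):

# Assumed button-update helpers in the style of FastChat's gradio_web_server;
# they are referenced below but not defined in this diff.
import gradio as gr

no_change_btn = gr.Button.update()                 # leave the button unchanged
enable_btn = gr.Button.update(interactive=True)    # make the button clickable
disable_btn = gr.Button.update(interactive=False)  # grey the button out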
@@ -106,7 +141,7 @@ def http_bot(state, model_selector, temperature, max_new_tokens, request: gr.Req
        yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
        return

-    if len(state.message) == state.offset + 2:
+    if len(state.messages) == state.offset + 2:
        new_state = get_default_conv_template(model_name).copy()
        new_state.conv_id = uuid.uuid4().hex
        new_state.append_message(new_state.roles[0], state.messages[-2][1])
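http_bot manipulates a conversation object that comes from pilot.conversation, outside this diff. A compressed sketch of the interface the code assumes; member names are taken from their uses in this file, and the bodies are stubs:

# Assumed interface of the conversation state used by http_bot.
# Only members referenced in this file are sketched; bodies are illustrative.
class Conversation:
    roles = ("USER", "ASSISTANT")

    def __init__(self):
        self.messages = []       # [[role, text], ...] pairs
        self.offset = 0          # messages belonging to the fixed prompt preamble
        self.conv_id = None      # set to uuid4().hex per conversation
        self.sep_style = None    # e.g. SeparatorStyle.SINGLE
        self.sep = "###"         # separator, doubles as the stop string

    def append_message(self, role, text):
        self.messages.append([role, text])

    def copy(self) -> "Conversation": ...
    def get_prompt(self) -> str: ...          # flatten messages into the model prompt
    def to_gradio_chatbot(self) -> list: ...  # [[user, assistant], ...] for gr.Chatbot
    def dict(self) -> dict: ...               # JSON-serializable form for logging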
@@ -114,4 +149,238 @@ def http_bot(state, model_selector, temperature, max_new_tokens, request: gr.Req
        state = new_state

    # TODO
    prompt = state.get_prompt()
    skip_echo_len = compute_skip_echo_len(prompt)

    payload = {
        "model": model_name,
        "prompt": prompt,
        "temperature": temperature,
        "max_new_tokens": max_new_tokens,
        "stop": state.sep if state.sep_style == SeparatorStyle.SINGLE else None,
    }
    logger.info(f"Request: \n {payload}")
    state.messages[-1][-1] = "▌"
    yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5

    try:
        response = requests.post(
            url=urljoin(vicuna_model_server, "generate_stream"),
            headers=headers,
            json=payload,
            stream=True,
            timeout=60,
        )
        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
            if chunk:
                data = json.loads(chunk.decode())
                if data["error_code"] == 0:
                    output = data["text"][skip_echo_len:].strip()
                    output = post_process_code(output)
                    state.messages[-1][-1] = output + "▌"
                    yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
                else:
                    output = data["text"] + f" (error_code: {data['error_code']})"
                    state.messages[-1][-1] = output
                    yield (state, state.to_gradio_chatbot()) + (
                        disable_btn,
                        disable_btn,
                        disable_btn,
                        enable_btn,
                        enable_btn,
                    )
                    return
                time.sleep(0.02)
    except requests.exceptions.RequestException as e:
        state.messages[-1][-1] = server_error_msg + " (error_code: 4)"
        yield (state, state.to_gradio_chatbot()) + (
            disable_btn,
            disable_btn,
            disable_btn,
            enable_btn,
            enable_btn,
        )
        return

    state.messages[-1][-1] = state.messages[-1][-1][:-1]
    yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
    finish_tstamp = time.time()
    logger.info(f"{output}")

    with open(get_conv_log_filename(), "a") as flog:
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name,
            "gen_params": {
                "temperature": temperature,
                "max_new_tokens": max_new_tokens,
            },
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),
            "state": state.dict(),
            "ip": request.client.host,
        }
        flog.write(json.dumps(data) + "\n")
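The try block above pins down the wire protocol with the model worker: POST a JSON payload to generate_stream, read NUL-delimited JSON fragments that each carry the full text so far plus an error_code, and slice off the echoed prompt with skip_echo_len. A minimal standalone client sketch under those assumptions; the endpoint and field names come from this diff, while the function name, base_url parameter, and the echo-skip computation (mirroring what compute_skip_echo_len presumably does) are assumptions:

# Hypothetical standalone client for the generate_stream protocol used above.
import json
import requests

def stream_generate(base_url, prompt, temperature=0.7, max_new_tokens=512):
    """Yield the accumulated completion as it streams in.

    Assumes the same protocol http_bot relies on: NUL-delimited JSON
    objects with "text" and "error_code" fields.
    """
    payload = {
        "model": "vicuna-13b",
        "prompt": prompt,
        "temperature": temperature,
        "max_new_tokens": max_new_tokens,
        "stop": None,
    }
    response = requests.post(
        f"{base_url}/generate_stream", json=payload, stream=True, timeout=60
    )
    skip_echo_len = len(prompt.replace("</s>", " ")) + 1  # assumed echo skip
    for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
        if not chunk:
            continue
        data = json.loads(chunk.decode())
        if data["error_code"] != 0:
            raise RuntimeError(f"worker error {data['error_code']}: {data['text']}")
        yield data["text"][skip_echo_len:].strip()

# Usage sketch:
# for text in stream_generate("http://127.0.0.1:8000", "USER: hello ASSISTANT:"):
#     print(text)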
block_css = (
    code_highlight_css
    + """
pre {
    white-space: pre-wrap;       /* Since CSS 2.1 */
    white-space: -moz-pre-wrap;  /* Mozilla, since 1999 */
    white-space: -pre-wrap;      /* Opera 4-6 */
    white-space: -o-pre-wrap;    /* Opera 7 */
    word-wrap: break-word;       /* Internet Explorer 5.5+ */
}
#notice_markdown th {
    display: none;
}
"""
)
def build_single_model_ui():

    notice_markdown = """
# DB-GPT

[DB-GPT](https://github.com/csunny/DB-GPT) is an experimental open-source application built on [FastChat](https://github.com/lm-sys/FastChat), using Vicuna as the base model. It additionally combines LangChain and llama-index, augmenting the model with database-related knowledge through In-Context Learning over an existing knowledge base. In short, it is an AI tool for working with databases. If you have specific questions about using or deploying DB-GPT, please ask.
"""
    learn_more_markdown = """
### License
The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA.
"""

    state = gr.State()
    notice = gr.Markdown(notice_markdown, elem_id="notice_markdown")
    chatbot = grChatbot(elem_id="chatbot", visible=False).style(height=550)
    with gr.Row():
        with gr.Column(scale=20):
            textbox = gr.Textbox(
                show_label=False,
                placeholder="Enter text and press ENTER",
                visible=False,
            ).style(container=False)

        with gr.Column(scale=2, min_width=50):
            send_btn = gr.Button(value="发送", visible=False)

    with gr.Row(visible=False) as button_row:
        regenerate_btn = gr.Button(value="🔄 重新生成", interactive=False)
        clear_btn = gr.Button(value="🗑️ 清理", interactive=False)

    with gr.Accordion("参数", open=False, visible=False) as parameter_row:
        temperature = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.7,
            step=0.1,
            interactive=True,
            label="Temperature",
        )

        max_output_tokens = gr.Slider(
            minimum=0,
            maximum=1024,
            value=512,
            step=64,
            interactive=True,
            label="最大输出Token数",
        )

    gr.Markdown(learn_more_markdown)

    btn_list = [regenerate_btn, clear_btn]
    regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
        http_bot,
        [state, temperature, max_output_tokens],
        [state, chatbot] + btn_list,
    )
    clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)

    textbox.submit(
        add_text, [state, textbox], [state, chatbot, textbox] + btn_list
    ).then(
        http_bot,
        [state, temperature, max_output_tokens],
        [state, chatbot] + btn_list,
    )

    send_btn.click(
        add_text, [state, textbox], [state, chatbot, textbox] + btn_list
    ).then(
        http_bot,
        [state, temperature, max_output_tokens],
        [state, chatbot] + btn_list,
    )

    return state, chatbot, textbox, send_btn, button_row, parameter_row
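The wiring above assumes add_text, regenerate, and clear_history handlers defined elsewhere in the file. For orientation, a hedged sketch of the shape the input/output lists imply for add_text, modeled on the FastChat handler this file adapts; the body is an assumption, not the committed code:

# Assumed shape of add_text (not in this diff): takes [state, textbox],
# returns updates for [state, chatbot, textbox] + btn_list.
def add_text(state, text):
    if state is None:
        state = get_default_conv_template(LLM_MODEL).copy()
    state.append_message(state.roles[0], text[:1536])  # cap user input length
    state.append_message(state.roles[1], None)         # placeholder for the reply
    return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 2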
def build_webdemo():
    with gr.Blocks(
        title="数据库智能助手",
        theme=gr.themes.Base(),
        # theme=gr.themes.Monochrome(),
        css=block_css,
    ) as demo:
        url_params = gr.JSON(visible=False)
        (
            state,
            chatbot,
            textbox,
            send_btn,
            button_row,
            parameter_row,
        ) = build_single_model_ui()

        if args.model_list_mode == "once":
            demo.load(
                load_demo,
                [url_params],
                [
                    state,
                    chatbot,
                    textbox,
                    send_btn,
                    button_row,
                    parameter_row,
                ],
                _js=get_window_url_params,
            )
        else:
            raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
    return demo


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int)
    parser.add_argument("--concurrency-count", type=int, default=10)
    parser.add_argument(
        "--model-list-mode", type=str, default="once", choices=["once", "reload"]
    )
    parser.add_argument("--share", default=False, action="store_true")
    parser.add_argument(
        "--moderate", action="store_true", help="Enable content moderation"
    )
    args = parser.parse_args()
    logger.info(f"args: {args}")

    set_global_vars(args.moderate)

    logger.info(args)
    demo = build_webdemo()
    demo.queue(
        concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False
    ).launch(
        server_name=args.host,
        server_port=args.port,
        share=args.share,
        max_threads=200,
    )