impl gradio web

csunny 2023-04-30 20:49:10 +08:00
parent d73cab9c4b
commit 90dd34e25a


@@ -1,13 +1,17 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import argparse
import os
import uuid
import json
import time
import gradio as gr
import datetime
import requests
from urllib.parse import urljoin
-from pilot.configs.model_config import LOGDIR
+from pilot.configs.model_config import LOGDIR, vicuna_model_server, LLM_MODEL
from pilot.conversation import (
get_default_conv_template,
@@ -39,10 +43,41 @@ priority = {
"vicuna-13b": "aaa"
}
-def set_global_vars(enable_moderation_, models_):
+def set_global_vars(enable_moderation_):
    global enable_moderation, models
    enable_moderation = enable_moderation_
-    models = models_
def load_demo_single(url_params):
    # The model-selector dropdown was removed in this commit, so only the six
    # outputs wired up in build_webdemo are returned here.
    state = None
    return (
        state,
        gr.Chatbot.update(visible=True),
        gr.Textbox.update(visible=True),
        gr.Button.update(visible=True),
        gr.Row.update(visible=True),
        gr.Accordion.update(visible=True),
    )
get_window_url_params = """
function() {
const params = new URLSearchParams(window.location.search);
    const url_params = Object.fromEntries(params);
console.log(url_params);
return url_params;
}
"""
def load_demo(url_params, request: gr.Request):
logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
return load_demo_single(url_params)
def get_conv_log_filename():
t = datetime.datetime.now()
@@ -94,11 +129,11 @@ def post_process_code(code):
code = sep.join(blocks)
return code
-def http_bot(state, model_selector, temperature, max_new_tokens, request: gr.Request):
+def http_bot(state, temperature, max_new_tokens, request: gr.Request):
logger.info(f"http_bot. ip: {request.client.host}")
start_tstamp = time.time()
-    model_name = model_selector
+    model_name = LLM_MODEL
temperature = float(temperature)
max_new_tokens = int(max_new_tokens)
@@ -106,7 +141,7 @@ def http_bot(state, model_selector, temperature, max_new_tokens, request: gr.Req
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
return
-    if len(state.message) == state.offset + 2:
+    if len(state.messages) == state.offset + 2:
new_state = get_default_conv_template(model_name).copy()
new_state.conv_id = uuid.uuid4().hex
new_state.append_message(new_state.roles[0], state.messages[-2][1])
@@ -114,4 +149,238 @@ def http_bot(state, model_selector, temperature, max_new_tokens, request: gr.Req
state = new_state
# TODO
prompt = state.get_prompt()
skip_echo_len = compute_skip_echo_len(prompt)
payload = {
"model": model_name,
"prompt": prompt,
"temperature": temperature,
"max_new_tokens": max_new_tokens,
"stop": state.sep if state.sep_style == SeparatorStyle.SINGLE else None,
}
logger.info(f"Request: \n {payload}")
state.messages[-1][-1] = ""
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
try:
response = requests.post(
url=urljoin(vicuna_model_server, "generate_stream"),
headers=headers,
json=payload,
stream=True,
timeout=60,
)
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk:
data = json.loads(chunk.decode())
if data["error_code"] == 0:
output = data["text"][skip_echo_len].strip()
output = post_process_code(output)
                    state.messages[-1][-1] = output + "▌"  # "▌" marks in-progress streaming; stripped below
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
else:
output = data["text"] + f" (error_code): {data['error_code']}"
state.messages[-1][-1] = output
yield (state, state.to_gradio_chatbot()) + (
disable_btn,
disable_btn,
disable_btn,
enable_btn,
enable_btn
)
return
time.sleep(0.02)
except requests.exceptions.RequestException as e:
        state.messages[-1][-1] = server_error_msg + " (error_code: 4)"
yield (state, state.to_gradio_chatbot()) + (
disable_btn,
disable_btn,
disable_btn,
enable_btn,
enable_btn
)
return
    state.messages[-1][-1] = state.messages[-1][-1][:-1]  # drop the streaming cursor
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
finish_tstamp = time.time()
logger.info(f"{output}")
with open(get_conv_log_filename(), "a") as flog:
data = {
"tstamp": round(finish_tstamp, 4),
"type": "chat",
"model": model_name,
"gen_params": {
"temperature": temperature,
"max_new_tokens": max_new_tokens,
},
"start": round(start_tstamp, 4),
"finish": round(finish_tstamp, 4),
"state": state.dict(),
"ip": request.client.host,
}
        flog.write(json.dumps(data) + "\n")
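# For reference: http_bot above expects the vicuna_model_server
# "generate_stream" endpoint to emit NUL-delimited JSON chunks whose "text"
# field echoes the prompt (hence skip_echo_len) and whose "error_code" is 0 on
# success. A minimal sketch of a compatible chunk producer; illustrative only,
# the project's real server is not part of this diff:
def _example_stream_chunks(prompt, pieces=("Hello", "Hello, world")):
    # Yields bytes in the framing http_bot parses with
    # response.iter_lines(decode_unicode=False, delimiter=b"\0").
    for partial in pieces:
        chunk = {"text": prompt + partial, "error_code": 0}
        yield json.dumps(chunk).encode() + b"\0"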
block_css = (
code_highlight_css
+ """
pre {
white-space: pre-wrap; /* Since CSS 2.1 */
white-space: -moz-pre-wrap; /* Mozilla, since 1999 */
white-space: -pre-wrap; /* Opera 4-6 */
white-space: -o-pre-wrap; /* Opera 7 */
word-wrap: break-word; /* Internet Explorer 5.5+ */
}
#notice_markdown th {
display: none;
}
"""
)
def build_single_model_ui():
notice_markdown = """
# DB-GPT
[DB-GPT](https://github.com/csunny/DB-GPT) is an experimental open-source application built on [FastChat](https://github.com/lm-sys/FastChat), using Vicuna as the base model. It also combines langchain and llama-index to enhance the model with database-related knowledge through In-Context Learning over an existing knowledge base. In short, it aims to be an innovative AI tool for database work; if you have any questions about how to use or deploy DB-GPT, please ask and we will do our best to help.
"""
learn_more_markdown = """
### License
The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA.
"""
state = gr.State()
notice = gr.Markdown(notice_markdown, elem_id="notice_markdown")
    chatbot = grChatbot(elem_id="chatbot", visible=False).style(height=550)  # grChatbot: patched Chatbot class imported elsewhere in this file
with gr.Row():
with gr.Column(scale=20):
textbox = gr.Textbox(
show_label=False,
placeholder="Enter text and press ENTER",
visible=False,
).style(container=False)
with gr.Column(scale=2, min_width=50):
            send_btn = gr.Button(value="发送", visible=False)  # "发送" = "Send"
with gr.Row(visible=False) as button_row:
        regenerate_btn = gr.Button(value="🔄重新生成", interactive=False)  # "重新生成" = "Regenerate"
        clear_btn = gr.Button(value="🗑️清理", interactive=False)  # "清理" = "Clear"
    with gr.Accordion("参数", open=False, visible=False) as parameter_row:  # "参数" = "Parameters"
temperature = gr.Slider(
minimum=0.0,
maximum=1.0,
value=0.7,
step=0.1,
interactive=True,
label="Temperature",
)
max_output_tokens = gr.Slider(
minimum=0,
maximum=1024,
value=512,
step=64,
interactive=True,
label="最大输出Token数",
)
gr.Markdown(learn_more_markdown)
btn_list = [regenerate_btn, clear_btn]
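    # Event chaining: .click(...)/.submit(...) first runs add_text/regenerate
    # to update the conversation state and textbox, then .then(...) streams
    # http_bot's yielded (state, chatbot) updates into the UI.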
regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
http_bot,
[state, temperature, max_output_tokens],
[state, chatbot] + btn_list,
)
clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)
textbox.submit(
add_text, [state, textbox], [state, chatbot, textbox] + btn_list
).then(
http_bot,
[state, temperature, max_output_tokens],
[state, chatbot] + btn_list,
)
send_btn.click(
add_text, [state, textbox], [state, chatbot, textbox] + btn_list
).then(
http_bot,
[state, temperature, max_output_tokens],
[state, chatbot] + btn_list
)
return state, chatbot, textbox, send_btn, button_row, parameter_row
def build_webdemo():
with gr.Blocks(
title="数据库智能助手",
theme=gr.themes.Base(),
# theme=gr.themes.Monochrome(),
css=block_css,
) as demo:
url_params = gr.JSON(visible=False)
(
state,
chatbot,
textbox,
send_btn,
button_row,
parameter_row,
) = build_single_model_ui()
if args.model_list_mode == "once":
demo.load(
load_demo,
[url_params],
[
state,
chatbot,
textbox,
send_btn,
button_row,
parameter_row,
],
_js=get_window_url_params,
)
else:
raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
return demo
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="0.0.0.0")
parser.add_argument("--port", type=int)
parser.add_argument("--concurrency-count", type=int, default=10)
parser.add_argument(
"--model-list-mode", type=str, default="once", choices=["once", "reload"]
)
parser.add_argument("--share", default=False, action="store_true")
parser.add_argument(
"--moderate", action="store_true", help="Enable content moderation"
)
args = parser.parse_args()
logger.info(f"args: {args}")
set_global_vars(args.moderate)
logger.info(args)
demo = build_webdemo()
demo.queue(
concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False
).launch(
server_name=args.host, server_port=args.port, share=args.share, max_threads=200,
)
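# Hypothetical launch, assuming this file is saved as webserver.py (the real
# path is not shown in this diff):
#   python webserver.py --port 7860 --concurrency-count 10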