diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py index 147af8b3d..ae16ac9ba 100644 --- a/pilot/server/webserver.py +++ b/pilot/server/webserver.py @@ -1,13 +1,17 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +import argparse import os import uuid +import json import time import gradio as gr import datetime +import requests +from urllib.parse import urljoin -from pilot.configs.model_config import LOGDIR +from pilot.configs.model_config import LOGDIR, vicuna_model_server, LLM_MODEL from pilot.conversation import ( get_default_conv_template, @@ -39,10 +43,41 @@ priority = { "vicuna-13b": "aaa" } -def set_global_vars(enable_moderation_, models_): +def set_global_vars(enable_moderation_): global enable_moderation, models enable_moderation = enable_moderation_ - models = models_ + +def load_demo_single(url_params): + dropdown_update = gr.Dropdown.update(visible=True) + if "model" in url_params: + model = url_params["model"] + if model in models: + dropdown_update = gr.Dropdown.update(value=model, visible=True) + + state = None + return ( + state, + dropdown_update, + gr.Chatbot.update(visible=True), + gr.Textbox.update(visible=True), + gr.Button.update(visible=True), + gr.Row.update(visible=True), + gr.Accordion.update(visible=True), + ) + + +get_window_url_params = """ +function() { + const params = new URLSearchParams(window.location.search); + url_params = Object.fromEntries(params); + console.log(url_params); + return url_params; + } +""" + +def load_demo(url_params, request: gr.Request): + logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}") + return load_demo_single(url_params) def get_conv_log_filename(): t = datetime.datetime.now() @@ -94,11 +129,11 @@ def post_process_code(code): code = sep.join(blocks) return code -def http_bot(state, model_selector, temperature, max_new_tokens, request: gr.Request): +def http_bot(state, temperature, max_new_tokens, request: gr.Request): logger.info(f"http_bot. 
ip: {request.client.host}") start_tstamp = time.time() - model_name = model_selector + model_name = LLM_MODEL temperature = float(temperature) max_new_tokens = int(max_new_tokens) @@ -106,7 +141,7 @@ def http_bot(state, model_selector, temperature, max_new_tokens, request: gr.Req yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5 return - if len(state.message) == state.offset + 2: + if len(state.messages) == state.offset + 2: new_state = get_default_conv_template(model_name).copy() new_state.conv_id = uuid.uuid4().hex new_state.append_message(new_state.roles[0], state.messages[-2][1]) @@ -114,4 +149,238 @@ def http_bot(state, model_selector, temperature, max_new_tokens, request: gr.Req state = new_state - # TODO \ No newline at end of file + prompt = state.get_prompt() + skip_echo_len = compute_skip_echo_len(prompt) + + payload = { + "model": model_name, + "prompt": prompt, + "temperature": temperature, + "max_new_tokens": max_new_tokens, + "stop": state.sep if state.sep_style == SeparatorStyle.SINGLE else None, + } + + logger.info(f"Request: \n {payload}") + state.messages[-1][-1] = "▌" + yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5 + + try: + response = requests.post( + url=urljoin(vicuna_model_server, "generate_stream"), + headers=headers, + json=payload, + stream=True, + timeout=60, + ) + for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): + if chunk: + data = json.loads(chunk.decode()) + if data["error_code"] == 0: + output = data["text"][skip_echo_len:].strip() + output = post_process_code(output) + state.messages[-1][-1] = output + "▌" + yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5 + else: + output = data["text"] + f" (error_code: {data['error_code']})" + state.messages[-1][-1] = output + yield (state, state.to_gradio_chatbot()) + ( + disable_btn, + disable_btn, + disable_btn, + enable_btn, + enable_btn + ) + + return + time.sleep(0.02) + except requests.exceptions.RequestException as e: + 
state.messages[-1][-1] = server_error_msg + f" (error_code: 4)" + yield (state, state.to_gradio_chatbot()) + ( + disable_btn, + disable_btn, + disable_btn, + enable_btn, + enable_btn + ) + return + + state.messages[-1][-1] = state.messages[-1][-1][:-1] + yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5 + + finish_tstamp = time.time() + logger.info(f"{output}") + + with open(get_conv_log_filename(), "a") as flog: + data = { + "tstamp": round(finish_tstamp, 4), + "type": "chat", + "model": model_name, + "gen_params": { + "temperature": temperature, + "max_new_tokens": max_new_tokens, + }, + "start": round(start_tstamp, 4), + "finish": round(finish_tstamp, 4), + "state": state.dict(), + "ip": request.client.host, + } + flog.write(json.dumps(data) + + "\n") + +block_css = ( + code_highlight_css + + """ +pre { + white-space: pre-wrap; /* Since CSS 2.1 */ + white-space: -moz-pre-wrap; /* Mozilla, since 1999 */ + white-space: -pre-wrap; /* Opera 4-6 */ + white-space: -o-pre-wrap; /* Opera 7 */ + word-wrap: break-word; /* Internet Explorer 5.5+ */ +} +#notice_markdown th { + display: none; +} + """ +) + + +def build_single_model_ui(): + + notice_markdown = """ + # DB-GPT + + [DB-GPT](https://github.com/csunny/DB-GPT) 是一个实验性的开源应用程序,它基于[FastChat](https://github.com/lm-sys/FastChat),并使用vicuna作为基础模型。此外,此程序结合了langchain和llama-index + ,基于现有知识库进行In-Context Learning来对其进行数据库相关知识的增强, 总的来说,它似乎是一个用于数据库的复杂且创新的AI工具。如果您对如何在工作中使用或实施 DB-GPT 有任何具体问题,请告诉我,我会尽力帮助您。 + """ + learn_more_markdown = """ + ### Licence + The service is a research preview intended for non-commercial use only. 
subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA + """ + + state = gr.State() + notice = gr.Markdown(notice_markdown, elem_id="notice_markdown") + + + chatbot = grChatbot(elem_id="chatbot", visible=False).style(height=550) + with gr.Row(): + with gr.Column(scale=20): + textbox = gr.Textbox( + show_label=False, + placeholder="Enter text and press ENTER", + visible=False, + ).style(container=False) + + with gr.Column(scale=2, min_width=50): + send_btn = gr.Button(value="" "发送", visible=False) + + with gr.Row(visible=False) as button_row: + regenerate_btn = gr.Button(value="🔄" "重新生成", interactive=False) + clear_btn = gr.Button(value="🗑️" "清理", interactive=False) + + with gr.Accordion("参数", open=False, visible=False) as parameter_row: + temperature = gr.Slider( + minimum=0.0, + maximum=1.0, + value=0.7, + step=0.1, + interactive=True, + label="Temperature", + ) + + max_output_tokens = gr.Slider( + minimum=0, + maximum=1024, + value=512, + step=64, + interactive=True, + label="最大输出Token数", + ) + + gr.Markdown(learn_more_markdown) + + + btn_list = [regenerate_btn, clear_btn] + regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then( + http_bot, + [state, temperature, max_output_tokens], + [state, chatbot] + btn_list, + ) + clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list) + + textbox.submit( + add_text, [state, textbox], [state, chatbot, textbox] + btn_list + ).then( + http_bot, + [state, temperature, max_output_tokens], + [state, chatbot] + btn_list, + ) + + send_btn.click( + add_text, [state, textbox], [state, chatbot, textbox] + btn_list + ).then( + http_bot, + [state, temperature, max_output_tokens], + [state, chatbot] + btn_list + ) + + return state, chatbot, textbox, send_btn, button_row, parameter_row + + +def build_webdemo(): + with gr.Blocks( + title="数据库智能助手", + theme=gr.themes.Base(), + # theme=gr.themes.Monochrome(), + css=block_css, + ) as demo: + 
url_params = gr.JSON(visible=False) + ( + state, + chatbot, + textbox, + send_btn, + button_row, + parameter_row, + ) = build_single_model_ui() + + if args.model_list_mode == "once": + demo.load( + load_demo, + [url_params], + [ + state, + chatbot, + textbox, + send_btn, + button_row, + parameter_row, + ], + _js=get_window_url_params, + ) + else: + raise ValueError(f"Unknown model list mode: {args.model_list_mode}") + return demo + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="0.0.0.0") + parser.add_argument("--port", type=int) + parser.add_argument("--concurrency-count", type=int, default=10) + parser.add_argument( + "--model-list-mode", type=str, default="once", choices=["once", "reload"] + ) + parser.add_argument("--share", default=False, action="store_true") + parser.add_argument( + "--moderate", action="store_true", help="Enable content moderation" + ) + args = parser.parse_args() + logger.info(f"args: {args}") + + set_global_vars(args.moderate) + + logger.info(args) + demo = build_webdemo() + demo.queue( + concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False + ).launch( + server_name=args.host, server_port=args.port, share=args.share, max_threads=200, + ) \ No newline at end of file