diff --git a/README.md b/README.md
index 370b035c5..97ce3b9bf 100644
--- a/README.md
+++ b/README.md
@@ -11,7 +11,7 @@ Overall, it appears to be a sophisticated and innovative tool for working with d
 1. Run model server
 ```
 cd pilot/server
-uvicorn vicuna_server:app --host 0.0.0.0
+python vicuna_server.py
 ```
 
 2. Run gradio webui
diff --git a/environment.yml b/environment.yml
index 3ec4dfd98..3ba070e56 100644
--- a/environment.yml
+++ b/environment.yml
@@ -60,3 +60,4 @@ dependencies:
     - gradio==3.24.1
     - gradio-client==0.0.8
     - wandb
+    - fschat==0.1.10
diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py
index 2b9ae6df0..80b1dabe4 100644
--- a/pilot/configs/model_config.py
+++ b/pilot/configs/model_config.py
@@ -7,6 +7,7 @@ import os
 root_path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 model_path = os.path.join(root_path, "models")
 vector_storepath = os.path.join(root_path, "vector_store")
+LOGDIR = os.path.join(root_path, "logs")
 
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
@@ -16,9 +17,9 @@ llm_model_config = {
 }
 
 LLM_MODEL = "vicuna-13b"
-
-
-vicuna_model_server = "http://192.168.31.114:8000/"
+LIMIT_MODEL_CONCURRENCY = 5
+MAX_POSITION_EMBEDDINGS = 2048
+vicuna_model_server = "http://192.168.31.114:8000"
 
 
 # Load model config
diff --git a/pilot/conversation.py b/pilot/conversation.py
new file mode 100644
index 000000000..8e172e7dd
--- /dev/null
+++ b/pilot/conversation.py
@@ -0,0 +1,141 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+
+import dataclasses
+from enum import auto, Enum
+from typing import List, Any
+
+
+class SeparatorStyle(Enum):
+
+    SINGLE = auto()
+    TWO = auto()
+
+@dataclasses.dataclass
+class Conversation:
+    """This class keeps all conversation history. """
+
+    system: str
+    roles: List[str]
+    messages: List[List[str]]
+    offset: int
+    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
+    sep: str = "###"
+    sep2: str = None
+
+    # Used for gradio server
+    skip_next: bool = False
+    conv_id: Any = None
+
+    def get_prompt(self):
+        if self.sep_style == SeparatorStyle.SINGLE:
+            ret = self.system + self.sep
+            for role, message in self.messages:
+                if message:
+                    ret += role + ": " + message + self.sep
+                else:
+                    ret += role + ":"
+            return ret
+
+        elif self.sep_style == SeparatorStyle.TWO:
+            seps = [self.sep, self.sep2]
+            ret = self.system + seps[0]
+            for i, (role, message) in enumerate(self.messages):
+                if message:
+                    ret += role + ":" + message + seps[i % 2]
+                else:
+                    ret += role + ":"
+            return ret
+
+        else:
+            raise ValueError(f"Invalid style: {self.sep_style}")
+
+
+    def append_message(self, role, message):
+        self.messages.append([role, message])
+
+    def to_gradio_chatbot(self):
+        ret = []
+        for i, (role, msg) in enumerate(self.messages[self.offset:]):
+            if i % 2 == 0:
+                ret.append([msg, None])
+            else:
+                ret[-1][-1] = msg
+
+        return ret
+
+    def copy(self):
+        return Conversation(
+            system=self.system,
+            roles=self.roles,
+            messages=[[x, y] for x, y in self.messages],
+            offset=self.offset,
+            sep_style=self.sep_style,
+            sep=self.sep,
+            sep2=self.sep2,
+            conv_id=self.conv_id,
+        )
+
+    def dict(self):
+        return {
+            "system": self.system,
+            "roles": self.roles,
+            "messages": self.messages,
+            "offset": self.offset,
+            "sep": self.sep,
+            "sep2": self.sep2,
+            "conv_id": self.conv_id
+        }
+
+
+conv_one_shot = Conversation(
+    system="A chat between a curious human and an artificial intelligence assistant, who is very familiar with database-related knowledge. "
" + "The assistant gives helpful, detailed, professional and polite answers to the human's questions. ", + roles=("Human", "Assistant"), + messages=( + ( + "Human", + "What are the key differences between mysql and postgres?", + ), + ( + "Assistant", + "MySQL and PostgreSQL are both popular open-source relational database management systems (RDBMS) " + "that have many similarities but also some differences. Here are some key differences: \n" + "1. Data Types: PostgreSQL has a more extensive set of data types, " + "including support for array, hstore, JSON, and XML, whereas MySQL has a more limited set.\n" + "2. ACID compliance: Both MySQL and PostgreSQL support ACID compliance (Atomicity, Consistency, Isolation, Durability), " + "but PostgreSQL is generally considered to be more strict in enforcing it.\n" + "3. Replication: MySQL has a built-in replication feature, which allows you to replicate data across multiple servers," + "whereas PostgreSQL has a similar feature, but it is not as mature as MySQL's.\n" + "4. Performance: MySQL is generally considered to be faster and more efficient in handling large datasets, " + "whereas PostgreSQL is known for its robustness and reliability.\n" + "5. Licensing: MySQL is licensed under the GPL (General Public License), which means that it is free and open-source software, " + "whereas PostgreSQL is licensed under the PostgreSQL License, which is also free and open-source but with different terms.\n" + + "Ultimately, the choice between MySQL and PostgreSQL depends on the specific needs and requirements of your application. " + "Both are excellent database management systems, and choosing the right one " + "for your project requires careful consideration of your application's requirements, performance needs, and scalability." + ), + ), + offset=2, + sep_style=SeparatorStyle.SINGLE, + sep="###" +) + +conv_vicuna_v1 = Conversation( + system = "A chat between a curious user and an artificial intelligence assistant. who very familiar with database related knowledge. " + "The assistant gives helpful, detailed, professional and polite answers to the user's questions. 
", + roles=("USER", "ASSISTANT"), + messages=(), + offset=0, + sep_style=SeparatorStyle.TWO, + sep=" ", + sep2="", +) + +default_conversation = conv_one_shot + +conv_templates = { + "conv_one_shot": conv_one_shot, + "vicuna_v1": conv_vicuna_v1 +} diff --git a/pilot/model/inference.py b/pilot/model/inference.py index 426043aa5..60d443f95 100644 --- a/pilot/model/inference.py +++ b/pilot/model/inference.py @@ -9,11 +9,8 @@ def generate_output(model, tokenizer, params, device, context_len=2048): temperature = float(params.get("temperature", 1.0)) max_new_tokens = int(params.get("max_new_tokens", 256)) stop_parameter = params.get("stop", None) - - print(tokenizer.__dir__()) if stop_parameter == tokenizer.eos_token: stop_parameter = None - stop_strings = [] if isinstance(stop_parameter, str): stop_strings.append(stop_parameter) @@ -43,33 +40,34 @@ def generate_output(model, tokenizer, params, device, context_len=2048): past_key_values=past_key_values, ) logits = out.logits - past_key_values = out.past_key_value + past_key_values = out.past_key_values + last_token_logits = logits[0][-1] if temperature < 1e-4: token = int(torch.argmax(last_token_logits)) else: - probs = torch.softmax(last_token_logits / temperature, dim=1) + probs = torch.softmax(last_token_logits / temperature, dim=-1) token = int(torch.multinomial(probs, num_samples=1)) - + output_ids.append(token) if token == tokenizer.eos_token_id: stopped = True else: stopped = False - + output = tokenizer.decode(output_ids, skip_special_tokens=True) for stop_str in stop_strings: pos = output.rfind(stop_str) if pos != -1: output = output[:pos] - stoppped = True + stopped = True break else: pass - - if stoppped: + + if stopped: break del past_key_values @@ -81,7 +79,9 @@ def generate_output(model, tokenizer, params, device, context_len=2048): @torch.inference_mode() def get_embeddings(model, tokenizer, prompt): input_ids = tokenizer(prompt).input_ids - input_embeddings = model.get_input_embeddings() - embeddings = input_embeddings(torch.LongTensor([input_ids])) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + input_embeddings = model.get_input_embeddings().to(device) + + embeddings = input_embeddings(torch.LongTensor([input_ids]).to(device)) mean = torch.mean(embeddings[0], 0).cpu().detach() - return mean \ No newline at end of file + return mean.to(device) diff --git a/pilot/server/embdserver.py b/pilot/server/embdserver.py new file mode 100644 index 000000000..97206f2d5 --- /dev/null +++ b/pilot/server/embdserver.py @@ -0,0 +1,3 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- + diff --git a/pilot/server/vicuna_server.py b/pilot/server/vicuna_server.py index 20ed928d0..f664410e8 100644 --- a/pilot/server/vicuna_server.py +++ b/pilot/server/vicuna_server.py @@ -1,16 +1,34 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +import uvicorn +import asyncio +import json from typing import Optional, List -from fastapi import FastAPI +from fastapi import FastAPI, Request, BackgroundTasks +from fastapi.responses import StreamingResponse +from fastchat.serve.inference import generate_stream from pydantic import BaseModel from pilot.model.inference import generate_output, get_embeddings +from fastchat.serve.inference import load_model from pilot.model.loader import ModerLoader from pilot.configs.model_config import * model_path = llm_model_config[LLM_MODEL] -ml = ModerLoader(model_path=model_path) -model, tokenizer = ml.loader(load_8bit=isload_8bit, debug=isdebug) + + +global_counter = 0 +model_semaphore = None + +# ml = 
+# model, tokenizer = ml.loader(load_8bit=isload_8bit, debug=isdebug)
+model, tokenizer = load_model(model_path=model_path, device=DEVICE, num_gpus=1, load_8bit=True, debug=False)
+
+class ModelWorker:
+    def __init__(self):
+        pass
+
+    # TODO
 
 
 app = FastAPI()
 
@@ -21,9 +39,58 @@ class PromptRequest(BaseModel):
     stop: Optional[List[str]] = None
 
 
+class StreamRequest(BaseModel):
+    model: str
+    prompt: str
+    temperature: float
+    max_new_tokens: int
+    stop: str
+
 class EmbeddingRequest(BaseModel):
     prompt: str
 
+def release_model_semaphore():
+    model_semaphore.release()
+
+
+def generate_stream_gate(params):
+    try:
+        for output in generate_stream(
+            model,
+            tokenizer,
+            params,
+            DEVICE,
+            MAX_POSITION_EMBEDDINGS,
+        ):
+            print("output: ", output)
+            ret = {
+                "text": output,
+                "error_code": 0,
+            }
+            yield json.dumps(ret).encode() + b"\0"
+    except torch.cuda.CudaError:
+        ret = {
+            "text": "**GPU OutOfMemory, Please Refresh.**",
+            "error_code": 0
+        }
+        yield json.dumps(ret).encode() + b"\0"
+
+
+@app.post("/generate_stream")
+async def api_generate_stream(request: Request):
+    global model_semaphore, global_counter
+    global_counter += 1
+    params = await request.json()
+    print(model, tokenizer, params, DEVICE)
+
+    if model_semaphore is None:
+        model_semaphore = asyncio.Semaphore(LIMIT_MODEL_CONCURRENCY)
+    await model_semaphore.acquire()
+
+    generator = generate_stream_gate(params)
+    background_tasks = BackgroundTasks()
+    background_tasks.add_task(release_model_semaphore)
+    return StreamingResponse(generator, background=background_tasks)
 
 @app.post("/generate")
 def generate(prompt_request: PromptRequest):
@@ -45,4 +112,9 @@ def embeddings(prompt_request: EmbeddingRequest):
     params = {"prompt": prompt_request.prompt}
     print("Received prompt: ", params["prompt"])
     output = get_embeddings(model, tokenizer, params["prompt"])
-    return {"response": [float(x) for x in output]}
\ No newline at end of file
+    return {"response": [float(x) for x in output]}
+
+
+
+if __name__ == "__main__":
+    uvicorn.run(app, host="0.0.0.0", log_level="info")
\ No newline at end of file
diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py
new file mode 100644
index 000000000..19ee4b697
--- /dev/null
+++ b/pilot/server/webserver.py
@@ -0,0 +1,352 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+import argparse
+import os
+import uuid
+import json
+import time
+import gradio as gr
+import datetime
+import requests
+from urllib.parse import urljoin
+
+from pilot.configs.model_config import LOGDIR, vicuna_model_server, LLM_MODEL
+
+from pilot.conversation import (
+    default_conversation,
+    conv_templates,
+    SeparatorStyle
+)
+
+from fastchat.utils import (
+    build_logger,
+    server_error_msg,
+    violates_moderation,
+    moderation_msg
+)
+
+from fastchat.serve.gradio_patch import Chatbot as grChatbot
+from fastchat.serve.gradio_css import code_highlight_css
+
+logger = build_logger("webserver", "webserver.log")
+headers = {"User-Agent": "dbgpt Client"}
+
+no_change_btn = gr.Button.update()
+enable_btn = gr.Button.update(interactive=True)
+disable_btn = gr.Button.update(interactive=False)
+
+enable_moderation = False
+models = []
+
+priority = {
+    "vicuna-13b": "aaa"
+}
+
+get_window_url_params = """
+function() {
+    const params = new URLSearchParams(window.location.search);
+    url_params = Object.fromEntries(params);
+    console.log(url_params);
+    gradioURL = window.location.href
+    if (!gradioURL.endsWith('?__theme=dark')) {
+        window.location.replace(gradioURL + '?__theme=dark');
+    }
+    return url_params;
+    }
+"""
+
+def load_demo(url_params, request: gr.Request):
+    logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
+
+    dropdown_update = gr.Dropdown.update(visible=True)
+    if "model" in url_params:
+        model = url_params["model"]
+        if model in models:
+            dropdown_update = gr.Dropdown.update(
+                value=model, visible=True)
+
+    state = default_conversation.copy()
+    return (state,
+            dropdown_update,
+            gr.Chatbot.update(visible=True),
+            gr.Textbox.update(visible=True),
+            gr.Button.update(visible=True),
+            gr.Row.update(visible=True),
+            gr.Accordion.update(visible=True))
+
+def get_conv_log_filename():
+    t = datetime.datetime.now()
+    name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
+    return name
+
+
+def regenerate(state, request: gr.Request):
+    logger.info(f"regenerate. ip: {request.client.host}")
+    state.messages[-1][-1] = None
+    state.skip_next = False
+    return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
+
+def clear_history(request: gr.Request):
+    logger.info(f"clear_history. ip: {request.client.host}")
+    state = None
+    return (state, [], "") + (disable_btn,) * 5
+
+
+def add_text(state, text, request: gr.Request):
+    logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
+    if len(text) <= 0:
+        state.skip_next = True
+        return (state, state.to_gradio_chatbot(), "") + (no_change_btn,) * 5
+    if args.moderate:
+        flagged = violates_moderation(text)
+        if flagged:
+            state.skip_next = True
+            return (state, state.to_gradio_chatbot(), moderation_msg) + (
+                no_change_btn,) * 5
+
+    text = text[:1536]  # Hard cut-off
+    state.append_message(state.roles[0], text)
+    state.append_message(state.roles[1], None)
+    state.skip_next = False
+    return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
+
+def post_process_code(code):
+    sep = "\n```"
+    if sep in code:
+        blocks = code.split(sep)
+        if len(blocks) % 2 == 1:
+            for i in range(1, len(blocks), 2):
+                blocks[i] = blocks[i].replace("\\_", "_")
+        code = sep.join(blocks)
+    return code
+
+def http_bot(state, temperature, max_new_tokens, request: gr.Request):
+    start_tstamp = time.time()
+    model_name = LLM_MODEL
+
+    if state.skip_next:
+        # This generate call is skipped due to invalid inputs
+        yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
+        return
+
+    if len(state.messages) == state.offset + 2:
+        # First round of conversation
+
+        template_name = "conv_one_shot"
+        new_state = conv_templates[template_name].copy()
+        new_state.conv_id = uuid.uuid4().hex
+        new_state.append_message(new_state.roles[0], state.messages[-2][1])
+        new_state.append_message(new_state.roles[1], None)
+        state = new_state
+
+    prompt = state.get_prompt()
+    skip_echo_len = len(prompt.replace("</s>", " ")) + 1
+
+    # Make requests
+    payload = {
+        "model": model_name,
+        "prompt": prompt,
+        "temperature": float(temperature),
+        "max_new_tokens": int(max_new_tokens),
+        "stop": state.sep if state.sep_style == SeparatorStyle.SINGLE else state.sep2,
+    }
+    logger.info(f"Request: \n{payload}")
+
+    state.messages[-1][-1] = "▌"
+    yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
+
+    try:
+        # Stream output
+        response = requests.post(urljoin(vicuna_model_server, "generate_stream"),
+                                 headers=headers, json=payload, stream=True, timeout=20)
+        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
+            if chunk:
+                data = json.loads(chunk.decode())
+                if data["error_code"] == 0:
+                    output = data["text"][skip_echo_len:].strip()
+                    output = post_process_code(output)
+                    state.messages[-1][-1] = output + "▌"
output + "▌" + yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5 + else: + output = data["text"] + f" (error_code: {data['error_code']})" + state.messages[-1][-1] = output + yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn) + return + time.sleep(0.02) + except requests.exceptions.RequestException as e: + state.messages[-1][-1] = server_error_msg + f" (error_code: 4)" + yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn) + return + + state.messages[-1][-1] = state.messages[-1][-1][:-1] + yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5 + + finish_tstamp = time.time() + logger.info(f"{output}") + + with open(get_conv_log_filename(), "a") as fout: + data = { + "tstamp": round(finish_tstamp, 4), + "type": "chat", + "model": model_name, + "start": round(start_tstamp, 4), + "finish": round(start_tstamp, 4), + "state": state.dict(), + "ip": request.client.host, + } + fout.write(json.dumps(data) + "\n") + +block_css = ( + code_highlight_css + + """ +pre { + white-space: pre-wrap; /* Since CSS 2.1 */ + white-space: -moz-pre-wrap; /* Mozilla, since 1999 */ + white-space: -pre-wrap; /* Opera 4-6 */ + white-space: -o-pre-wrap; /* Opera 7 */ + word-wrap: break-word; /* Internet Explorer 5.5+ */ +} +#notice_markdown th { + display: none; +} + """ +) + + +def build_single_model_ui(): + + notice_markdown = """ + # DB-GPT + + [DB-GPT](https://github.com/csunny/DB-GPT) 是一个实验性的开源应用程序,它基于[FastChat](https://github.com/lm-sys/FastChat),并使用vicuna-13b作为基础模型。此外,此程序结合了langchain和llama-index基于现有知识库进行In-Context Learning来对其进行数据库相关知识的增强。它可以进行SQL生成、SQL诊断、数据库知识问答等一系列的工作。 总的来说,它是一个用于数据库的复杂且创新的AI工具。如果您对如何在工作中使用或实施DB-GPT有任何具体问题,请联系我, 我会尽力提供帮助, 同时也欢迎大家参与到项目建设中, 做一些有趣的事情。 + """ + learn_more_markdown = """ + ### Licence + The service is a research preview intended for non-commercial use only. 
+    """
+
+    state = gr.State()
+    notice = gr.Markdown(notice_markdown, elem_id="notice_markdown")
+
+    with gr.Accordion("参数", open=False, visible=False) as parameter_row:
+        temperature = gr.Slider(
+            minimum=0.0,
+            maximum=1.0,
+            value=0.7,
+            step=0.1,
+            interactive=True,
+            label="Temperature",
+        )
+
+        max_output_tokens = gr.Slider(
+            minimum=0,
+            maximum=1024,
+            value=512,
+            step=64,
+            interactive=True,
+            label="最大输出Token数",
+        )
+
+    chatbot = grChatbot(elem_id="chatbot", visible=False).style(height=550)
+    with gr.Row():
+        with gr.Column(scale=20):
+            textbox = gr.Textbox(
+                show_label=False,
+                placeholder="Enter text and press ENTER",
+                visible=False,
+            ).style(container=False)
+
+        with gr.Column(scale=2, min_width=50):
+            send_btn = gr.Button(value="" "发送", visible=False)
+
+
+    with gr.Row(visible=False) as button_row:
+        regenerate_btn = gr.Button(value="🔄" "重新生成", interactive=False)
+        clear_btn = gr.Button(value="🗑️" "清理", interactive=False)
+
+    gr.Markdown(learn_more_markdown)
+
+    btn_list = [regenerate_btn, clear_btn]
+    regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
+        http_bot,
+        [state, temperature, max_output_tokens],
+        [state, chatbot] + btn_list,
+    )
+    clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)
+
+    textbox.submit(
+        add_text, [state, textbox], [state, chatbot, textbox] + btn_list
+    ).then(
+        http_bot,
+        [state, temperature, max_output_tokens],
+        [state, chatbot] + btn_list,
+    )
+
+    send_btn.click(
+        add_text, [state, textbox], [state, chatbot, textbox] + btn_list
+    ).then(
+        http_bot,
+        [state, temperature, max_output_tokens],
+        [state, chatbot] + btn_list
+    )
+
+    return state, chatbot, textbox, send_btn, button_row, parameter_row
+
+
+def build_webdemo():
+    with gr.Blocks(
+        title="数据库智能助手",
+        # theme=gr.themes.Base(),
+        theme=gr.themes.Default(),
+        css=block_css,
+    ) as demo:
+        url_params = gr.JSON(visible=False)
+        (
+            state,
+            chatbot,
+            textbox,
+            send_btn,
+            button_row,
+            parameter_row,
+        ) = build_single_model_ui()
+
+        if args.model_list_mode == "once":
+            demo.load(
+                load_demo,
+                [url_params],
+                [
+                    state,
+                    chatbot,
+                    textbox,
+                    send_btn,
+                    button_row,
+                    parameter_row,
+                ],
+                _js=get_window_url_params,
+            )
+        else:
+            raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
+    return demo
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--host", type=str, default="0.0.0.0")
+    parser.add_argument("--port", type=int)
+    parser.add_argument("--concurrency-count", type=int, default=10)
+    parser.add_argument(
+        "--model-list-mode", type=str, default="once", choices=["once", "reload"]
+    )
+    parser.add_argument("--share", default=False, action="store_true")
+    parser.add_argument(
+        "--moderate", action="store_true", help="Enable content moderation"
+    )
+    args = parser.parse_args()
+    logger.info(f"args: {args}")
+
+    logger.info(args)
+    demo = build_webdemo()
+    demo.queue(
+        concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False
+    ).launch(
+        server_name=args.host, server_port=args.port, share=args.share, max_threads=200,
+    )
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index dd7bf5189..50354b3b4 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -49,4 +49,5 @@ umap-learn
 notebook
 gradio==3.24.1
 gradio-client==0.0.8
-wandb
\ No newline at end of file
+wandb
+fschat==0.1.10
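
For reference, a minimal usage sketch of the `Conversation` template added in `pilot/conversation.py`. This is illustrative only: it assumes the repository root is on `PYTHONPATH` so that `pilot` is importable, and the question text is made up.

```python
# Sketch: building a vicuna_v1-style prompt from the new Conversation template.
from pilot.conversation import conv_templates

conv = conv_templates["vicuna_v1"].copy()  # copy() turns the messages tuple into a mutable list
conv.append_message(conv.roles[0], "Which index type suits a JSONB column in PostgreSQL?")
conv.append_message(conv.roles[1], None)   # empty slot for the model's reply

# With SeparatorStyle.TWO the prompt alternates sep and sep2:
# "<system> USER:<question> ASSISTANT:"
print(conv.get_prompt())
```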
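
The new `/generate_stream` endpoint emits JSON chunks terminated by a NUL byte, which is exactly how `http_bot` in `webserver.py` consumes them. Below is a standalone client sketch; it assumes the model server is reachable at the `vicuna_model_server` address from `model_config.py`, and the payload values are examples shaped after the `conv_one_shot` template.

```python
# Sketch: consuming /generate_stream outside the gradio UI.
import json
import requests

payload = {
    "model": "vicuna-13b",
    "prompt": "A chat between a curious human and an artificial intelligence assistant."
              "###Human: What are the key differences between mysql and postgres?###Assistant:",
    "temperature": 0.7,
    "max_new_tokens": 256,
    "stop": "###",
}

response = requests.post(
    "http://192.168.31.114:8000/generate_stream",  # vicuna_model_server
    headers={"User-Agent": "dbgpt Client"},
    json=payload,
    stream=True,
    timeout=20,
)

# Each chunk is a JSON object followed by b"\0", as produced by generate_stream_gate.
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
    if chunk:
        data = json.loads(chunk.decode())
        if data["error_code"] == 0:
            print(data["text"])  # cumulative text generated so far, prompt echo included
```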
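
The blocking `/generate` and `/embeddings` endpoints keep their existing request models. A short sketch of calling them follows, again with an illustrative host; the `/generate` fields mirror `PromptRequest` only as far as it is visible in this diff, so treat them as assumptions.

```python
# Sketch: the non-streaming /generate and /embeddings endpoints.
import requests

BASE = "http://192.168.31.114:8000"  # vicuna_model_server in model_config.py

gen = requests.post(
    f"{BASE}/generate",
    json={
        "prompt": "Explain the difference between a clustered and a non-clustered index.",
        "temperature": 0.7,
        "max_new_tokens": 256,
        "stop": ["###"],
    },
)
print(gen.json())

emb = requests.post(f"{BASE}/embeddings", json={"prompt": "database index"})
vector = emb.json()["response"]  # list of floats: the mean input embedding
print(len(vector))
```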