From 4381e21d0c0b891c6e9ef39aa28a644e67946645 Mon Sep 17 00:00:00 2001 From: csunny Date: Sun, 30 Apr 2023 14:09:13 +0800 Subject: [PATCH 01/17] add conversation prompt --- pilot/conversation.py | 151 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 151 insertions(+) create mode 100644 pilot/conversation.py diff --git a/pilot/conversation.py b/pilot/conversation.py new file mode 100644 index 000000000..4db5d9548 --- /dev/null +++ b/pilot/conversation.py @@ -0,0 +1,151 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- + +import dataclasses +from enum import auto, Enum +from typing import List, Tuple, Any + + +class SeparatorStyle(Enum): + + SINGLE = auto() + TWO = auto() + +@dataclasses.dataclass +class Conversation: + """This class keeps all conversation history. """ + + system: str + roles: List[str] + messages: List[List[str]] + offset: int + sep_style: SeparatorStyle = SeparatorStyle.SINGLE + sep: str = "###" + sep2: str = None + + # Used for gradio server + skip_next: bool = False + conv_id: Any = None + + def get_prompt(self): + if self.sep_style == SeparatorStyle.SINGLE: + ret = self.system + for role, message in self.messages: + if message: + ret += self.sep + " " + role + ": " + message + else: + ret += self.sep + " " + role + ":" + return ret + + elif self.sep_style == SeparatorStyle.TWO: + seps = [self.sep, self.sep2] + ret = self.system + seps[0] + for i, (role, message) in enumerate(self.messages): + if message: + ret += role + ":" + message + seps[i % 2] + else: + ret += role + ":" + return ret + + else: + raise ValueError(f"Invalid style: {self.sep_style}") + + + def append_message(self, role, message): + self.messages.append([role, message]) + + def to_gradio_chatbot(self): + ret = [] + for i, (role, msg) in enumerate(self.messages[self.offset :]): + if i % 2 == 0: + ret.append([msg, None]) + else: + ret[-1][-1] = msg + + return ret + + def copy(self): + return Conversation( + system=self.system, + roles=self.roles, + messages=[[x, y] for 
x, y in self.messages], + offset=self.offset, + sep_style=self.sep_style, + sep=self.sep, + sep2=self.sep2, + conv_id=self.conv_id, + ) + + def dict(self): + return { + "system": self.system, + "roles": self.roles, + "messages": self.messages, + "offset": self.offset, + "sep": self.sep, + "sep2": self.sep2, + "conv_id": self.conv_id + } + + +conv_one_shot = Conversation( + system="A chat between a curious human and an artificial intelligence assistant, who very familiar with database related knowledge. " + "The assistant gives helpful, detailed, professional and polite answers to the human's questions. ", + roles=("Human", "Assistant"), + messages=( + ( + "Human", + "What are the key differences between mysql and postgres?", + ), + ( + "Assistant", + "MySQL and PostgreSQL are both popular open-source relational database management systems (RDBMS) " + "that have many similarities but also some differences. Here are some key differences: \n" + "1. Data Types: PostgreSQL has a more extensive set of data types, " + "including support for array, hstore, JSON, and XML, whereas MySQL has a more limited set.\n" + "2. ACID compliance: Both MySQL and PostgreSQL support ACID compliance (Atomicity, Consistency, Isolation, Durability), " + "but PostgreSQL is generally considered to be more strict in enforcing it.\n" + "3. Replication: MySQL has a built-in replication feature, which allows you to replicate data across multiple servers," + "whereas PostgreSQL has a similar feature, but it is not as mature as MySQL's.\n" + "4. Performance: MySQL is generally considered to be faster and more efficient in handling large datasets, " + "whereas PostgreSQL is known for its robustness and reliability.\n" + "5. 
Licensing: MySQL is licensed under the GPL (General Public License), which means that it is free and open-source software, " + "whereas PostgreSQL is licensed under the PostgreSQL License, which is also free and open-source but with different terms.\n" + + "Ultimately, the choice between MySQL and PostgreSQL depends on the specific needs and requirements of your application. " + "Both are excellent database management systems, and choosing the right one " + "for your project requires careful consideration of your application's requirements, performance needs, and scalability." + ), + ), + offset=2, + sep_style=SeparatorStyle.SINGLE, + sep="###" +) + +conv_vicuna_v1 = Conversation( + system = "A chat between a curious user and an artificial intelligence assistant, who very familiar with database related knowledge. " + "The assistant gives helpful, detailed, professional and polite answers to the user's questions. ", + roles=("USER", "ASSISTANT"), + messages=(), + offset=0, + sep_style=SeparatorStyle.TWO, + sep=" ", + sep2="", +) + +conv_template = { + "conv_one_shot": conv_one_shot, + "vicuna_v1": conv_vicuna_v1 +} + + +def get_default_conv_template(model_name: str = "vicuna-13b"): + model_name = model_name.lower() + if "vicuna" in model_name: + return conv_vicuna_v1 + return conv_one_shot + + +def compute_skip_echo_len(prompt): + skip_echo_len = len(prompt) + 1 - prompt.count("") * 3 + return skip_echo_len \ No newline at end of file From 7c69dc248a0b847b3bc4ee05f81660d165e3345a Mon Sep 17 00:00:00 2001 From: csunny Date: Sun, 30 Apr 2023 14:21:58 +0800 Subject: [PATCH 02/17] checkout load_model to fastchat --- pilot/model/inference.py | 3 ++- pilot/server/vicuna_server.py | 14 +++++++++++--- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/pilot/model/inference.py b/pilot/model/inference.py index 426043aa5..4a013026f 100644 --- a/pilot/model/inference.py +++ b/pilot/model/inference.py @@ -84,4 +84,5 @@ def get_embeddings(model, tokenizer, prompt): 
input_embeddings = model.get_input_embeddings() embeddings = input_embeddings(torch.LongTensor([input_ids])) mean = torch.mean(embeddings[0], 0).cpu().detach() - return mean \ No newline at end of file + return mean + diff --git a/pilot/server/vicuna_server.py b/pilot/server/vicuna_server.py index 20ed928d0..9fa5ddbee 100644 --- a/pilot/server/vicuna_server.py +++ b/pilot/server/vicuna_server.py @@ -1,17 +1,20 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +import uvicorn from typing import Optional, List from fastapi import FastAPI from pydantic import BaseModel from pilot.model.inference import generate_output, get_embeddings +from fastchat.serve.inference import load_model from pilot.model.loader import ModerLoader from pilot.configs.model_config import * model_path = llm_model_config[LLM_MODEL] -ml = ModerLoader(model_path=model_path) -model, tokenizer = ml.loader(load_8bit=isload_8bit, debug=isdebug) +# ml = ModerLoader(model_path=model_path) +# model, tokenizer = ml.loader(load_8bit=isload_8bit, debug=isdebug) +model, tokenizer = load_model(model_path=model_path, device=DEVICE, num_gpus=1, load_8bit=True, debug=False) app = FastAPI() class PromptRequest(BaseModel): @@ -45,4 +48,9 @@ def embeddings(prompt_request: EmbeddingRequest): params = {"prompt": prompt_request.prompt} print("Received prompt: ", params["prompt"]) output = get_embeddings(model, tokenizer, params["prompt"]) - return {"response": [float(x) for x in output]} \ No newline at end of file + return {"response": [float(x) for x in output]} + + + +if __name__ == "__main__": + uvicorn.run(app, host="0.0.0.0", log_level="info") \ No newline at end of file From bb6c1865e164cb06aeebb49910436b4e6569696e Mon Sep 17 00:00:00 2001 From: csunny Date: Sun, 30 Apr 2023 14:22:45 +0800 Subject: [PATCH 03/17] update readme file --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 370b035c5..97ce3b9bf 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 
@@ Overall, it appears to be a sophisticated and innovative tool for working with d 1. Run model server ``` cd pilot/server -uvicorn vicuna_server:app --host 0.0.0.0 +python vicuna_server.py ``` 2. Run gradio webui From e493f5480421ec73bf297b3a75945c7bce9335e3 Mon Sep 17 00:00:00 2001 From: csunny Date: Sun, 30 Apr 2023 14:47:33 +0800 Subject: [PATCH 04/17] fix multi device error --- pilot/model/inference.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/pilot/model/inference.py b/pilot/model/inference.py index 4a013026f..b15e1d749 100644 --- a/pilot/model/inference.py +++ b/pilot/model/inference.py @@ -81,8 +81,9 @@ def generate_output(model, tokenizer, params, device, context_len=2048): @torch.inference_mode() def get_embeddings(model, tokenizer, prompt): input_ids = tokenizer(prompt).input_ids - input_embeddings = model.get_input_embeddings() - embeddings = input_embeddings(torch.LongTensor([input_ids])) - mean = torch.mean(embeddings[0], 0).cpu().detach() - return mean + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + input_embeddings = model.get_input_embeddings().to(device) + embeddings = input_embeddings(torch.LongTensor([input_ids]).to(device)) + mean = torch.mean(embeddings[0], 0).cpu().detach() + return mean.to(device) From 909045c0d69b508e2b27ad2353d27784a29fcefb Mon Sep 17 00:00:00 2001 From: csunny Date: Sun, 30 Apr 2023 14:52:03 +0800 Subject: [PATCH 05/17] update --- pilot/model/inference.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/pilot/model/inference.py b/pilot/model/inference.py index b15e1d749..60d443f95 100644 --- a/pilot/model/inference.py +++ b/pilot/model/inference.py @@ -9,11 +9,8 @@ def generate_output(model, tokenizer, params, device, context_len=2048): temperature = float(params.get("temperature", 1.0)) max_new_tokens = int(params.get("max_new_tokens", 256)) stop_parameter = params.get("stop", None) - - print(tokenizer.__dir__()) if 
stop_parameter == tokenizer.eos_token: stop_parameter = None - stop_strings = [] if isinstance(stop_parameter, str): stop_strings.append(stop_parameter) @@ -43,33 +40,34 @@ def generate_output(model, tokenizer, params, device, context_len=2048): past_key_values=past_key_values, ) logits = out.logits - past_key_values = out.past_key_value + past_key_values = out.past_key_values + last_token_logits = logits[0][-1] if temperature < 1e-4: token = int(torch.argmax(last_token_logits)) else: - probs = torch.softmax(last_token_logits / temperature, dim=1) + probs = torch.softmax(last_token_logits / temperature, dim=-1) token = int(torch.multinomial(probs, num_samples=1)) - + output_ids.append(token) if token == tokenizer.eos_token_id: stopped = True else: stopped = False - + output = tokenizer.decode(output_ids, skip_special_tokens=True) for stop_str in stop_strings: pos = output.rfind(stop_str) if pos != -1: output = output[:pos] - stoppped = True + stopped = True break else: pass - - if stoppped: + + if stopped: break del past_key_values From d73cab9c4b4f269e86be343daff4dee9651ce66f Mon Sep 17 00:00:00 2001 From: csunny Date: Sun, 30 Apr 2023 16:05:31 +0800 Subject: [PATCH 06/17] add stream output --- pilot/configs/model_config.py | 5 +- pilot/server/embdserver.py | 3 + pilot/server/vicuna_server.py | 52 ++++++++++++++- pilot/server/webserver.py | 117 ++++++++++++++++++++++++++++++++++ 4 files changed, 174 insertions(+), 3 deletions(-) create mode 100644 pilot/server/embdserver.py create mode 100644 pilot/server/webserver.py diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py index 2b9ae6df0..cf26ff7e1 100644 --- a/pilot/configs/model_config.py +++ b/pilot/configs/model_config.py @@ -7,6 +7,7 @@ import os root_path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) model_path = os.path.join(root_path, "models") vector_storepath = os.path.join(root_path, "vector_store") +LOGDIR = os.path.join(root_path, "logs") DEVICE = 
"cuda" if torch.cuda.is_available() else "cpu" @@ -16,8 +17,8 @@ llm_model_config = { } LLM_MODEL = "vicuna-13b" - - +LIMIT_MODEL_CONCURRENCY = 5 +MAX_POSITION_EMBEDDINGS = 2048 vicuna_model_server = "http://192.168.31.114:8000/" diff --git a/pilot/server/embdserver.py b/pilot/server/embdserver.py new file mode 100644 index 000000000..97206f2d5 --- /dev/null +++ b/pilot/server/embdserver.py @@ -0,0 +1,3 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- + diff --git a/pilot/server/vicuna_server.py b/pilot/server/vicuna_server.py index 9fa5ddbee..91aba5dec 100644 --- a/pilot/server/vicuna_server.py +++ b/pilot/server/vicuna_server.py @@ -2,8 +2,12 @@ # -*- coding: utf-8 -*- import uvicorn +import asyncio +import json from typing import Optional, List -from fastapi import FastAPI +from fastapi import FastAPI, Request, BackgroundTasks +from fastapi.responses import StreamingResponse +from fastchat.serve.inference import generate_stream from pydantic import BaseModel from pilot.model.inference import generate_output, get_embeddings from fastchat.serve.inference import load_model @@ -11,6 +15,11 @@ from pilot.model.loader import ModerLoader from pilot.configs.model_config import * model_path = llm_model_config[LLM_MODEL] + + +global_counter = 0 +model_semaphore = None + # ml = ModerLoader(model_path=model_path) # model, tokenizer = ml.loader(load_8bit=isload_8bit, debug=isdebug) @@ -27,6 +36,47 @@ class PromptRequest(BaseModel): class EmbeddingRequest(BaseModel): prompt: str +def release_model_semaphore(): + model_semaphore.release() + + +def generate_stream_gate(params): + try: + for output in generate_stream( + model, + tokenizer, + params, + DEVICE, + MAX_POSITION_EMBEDDINGS, + 2, + ): + ret = { + "text": output, + "error_code": 0, + } + yield json.dumps(ret).encode() + b"\0" + except torch.cuda.CudaError: + ret = { + "text": "**GPU OutOfMemory, Please Refresh.**", + "error_code": 0 + } + yield json.dumps(ret).encode() + b"\0" + + +@app.post("/generate_stream") 
+async def api_generate_stream(request: Request): + global model_semaphore, global_counter + global_counter += 1 + params = await request.json() + + if model_semaphore is None: + model_semaphore = asyncio.Semaphore(LIMIT_MODEL_CONCURRENCY) + await model_semaphore.acquire() + + generator = generate_stream_gate(params) + background_tasks = BackgroundTasks() + background_tasks.add_task(release_model_semaphore) + return StreamingResponse(generator, background=background_tasks) @app.post("/generate") def generate(prompt_request: PromptRequest): diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py new file mode 100644 index 000000000..147af8b3d --- /dev/null +++ b/pilot/server/webserver.py @@ -0,0 +1,117 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +import os +import uuid +import time +import gradio as gr +import datetime + +from pilot.configs.model_config import LOGDIR + +from pilot.conversation import ( + get_default_conv_template, + compute_skip_echo_len, + SeparatorStyle +) + +from fastchat.utils import ( + build_logger, + server_error_msg, + violates_moderation, + moderation_msg +) + +from fastchat.serve.gradio_patch import Chatbot as grChatbot +from fastchat.serve.gradio_css import code_highlight_css + +logger = build_logger("webserver", "webserver.log") +headers = {"User-Agent": "dbgpt Client"} + +no_change_btn = gr.Button.update() +enable_btn = gr.Button.update(interactive=True) +disable_btn = gr.Button.update(interactive=True) + +enable_moderation = False +models = [] + +priority = { + "vicuna-13b": "aaa" +} + +def set_global_vars(enable_moderation_, models_): + global enable_moderation, models + enable_moderation = enable_moderation_ + models = models_ + +def get_conv_log_filename(): + t = datetime.datetime.now() + name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json") + return name + + +def regenerate(state, request: gr.Request): + logger.info(f"regenerate. 
ip: {request.client.host}") + state.messages[-1][-1] = None + state.skip_next = False + return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5 + +def clear_history(request: gr.Request): + logger.info(f"clear_history. ip: {request.client.host}") + state = None + return (state, [], "") + (disable_btn,) * 5 + +def add_text(state, text, request: gr.Request): + logger.info(f"add_text. ip: {request.client.host}. len:{len(text)}") + + if state is None: + state = get_default_conv_template("vicuna").copy() + + if len(text) <= 0: + state.skip_next = True + return (state, state.to_gradio_chatbot(), "") + (no_change_btn,) * 5 + + if enable_moderation: + flagged = violates_moderation(text) + if flagged: + logger.info(f"violate moderation. ip: {request.client.host}. text: {text}") + state.skip_next = True + return (state, state.to_gradio_chatbot(), moderation_msg) + (no_change_btn,) * 5 + text = text[:1536] # ? + state.append_message(state.roles[0], text) + state.append_message(state.roles[1], None) + state.skip_next = False + + return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5 + +def post_process_code(code): + sep = "\n```" + if sep in code: + blocks = code.split(sep) + if len(blocks) % 2 == 1: + for i in range(1, len(blocks), 2): + blocks[i] = blocks[i].replace("\\_", "_") + code = sep.join(blocks) + return code + +def http_bot(state, model_selector, temperature, max_new_tokens, request: gr.Request): + logger.info(f"http_bot. 
ip: {request.client.host}") + start_tstamp = time.time() + + model_name = model_selector + temperature = float(temperature) + max_new_tokens = int(max_new_tokens) + + if state.skip_next: + yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5 + return + + if len(state.message) == state.offset + 2: + new_state = get_default_conv_template(model_name).copy() + new_state.conv_id = uuid.uuid4().hex + new_state.append_message(new_state.roles[0], state.messages[-2][1]) + new_state.append_message(new_state.roles[1], None) + state = new_state + + + # TODO \ No newline at end of file From 90dd34e25a3a761b6afff19312896641b4a2ffe2 Mon Sep 17 00:00:00 2001 From: csunny Date: Sun, 30 Apr 2023 20:49:10 +0800 Subject: [PATCH 07/17] impl gradio web --- pilot/server/webserver.py | 283 +++++++++++++++++++++++++++++++++++++- 1 file changed, 276 insertions(+), 7 deletions(-) diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py index 147af8b3d..ae16ac9ba 100644 --- a/pilot/server/webserver.py +++ b/pilot/server/webserver.py @@ -1,13 +1,17 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +import argparse import os import uuid +import json import time import gradio as gr import datetime +import requests +from urllib.parse import urljoin -from pilot.configs.model_config import LOGDIR +from pilot.configs.model_config import LOGDIR, vicuna_model_server, LLM_MODEL from pilot.conversation import ( get_default_conv_template, @@ -39,10 +43,41 @@ priority = { "vicuna-13b": "aaa" } -def set_global_vars(enable_moderation_, models_): +def set_global_vars(enable_moderation_): global enable_moderation, models enable_moderation = enable_moderation_ - models = models_ + +def load_demo_single(url_params): + dropdown_update = gr.Dropdown.update(visible=True) + if "model" in url_params: + model = url_params["model"] + if model in models: + dropdown_update = gr.Dropdown.update(value=model, visible=True) + + state = None + return ( + state, + dropdown_update, + 
gr.Chatbot.update(visible=True), + gr.Textbox.update(visible=True), + gr.Button.update(visible=True), + gr.Row.update(visible=True), + gr.Accordion.update(visible=True), + ) + + +get_window_url_params = """ +function() { + const params = new URLSearchParams(window.location.search); + url_params = Object.fromEntries(params); + console.log(url_params); + return url_params; + } +""" + +def load_demo(url_params, request: gr.Request): + logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}") + return load_demo_single(url_params) def get_conv_log_filename(): t = datetime.datetime.now() @@ -94,11 +129,11 @@ def post_process_code(code): code = sep.join(blocks) return code -def http_bot(state, model_selector, temperature, max_new_tokens, request: gr.Request): +def http_bot(state, temperature, max_new_tokens, request: gr.Request): logger.info(f"http_bot. ip: {request.client.host}") start_tstamp = time.time() - model_name = model_selector + model_name = LLM_MODEL temperature = float(temperature) max_new_tokens = int(max_new_tokens) @@ -106,7 +141,7 @@ def http_bot(state, model_selector, temperature, max_new_tokens, request: gr.Req yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5 return - if len(state.message) == state.offset + 2: + if len(state.messages) == state.offset + 2: new_state = get_default_conv_template(model_name).copy() new_state.conv_id = uuid.uuid4().hex new_state.append_message(new_state.roles[0], state.messages[-2][1]) @@ -114,4 +149,238 @@ def http_bot(state, model_selector, temperature, max_new_tokens, request: gr.Req state = new_state - # TODO \ No newline at end of file + prompt = state.get_prompt() + skip_echo_len = compute_skip_echo_len(prompt) + + payload = { + "model": model_name, + "prompt": prompt, + "temperature": temperature, + "max_new_tokens": max_new_tokens, + "stop": state.sep if state.sep_style == SeparatorStyle.SINGLE else None, + } + + logger.info(f"Request: \n {payload}") + state.messages[-1][-1] = "▌" + 
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5 + + try: + response = requests.post( + url=urljoin(vicuna_model_server, "generate_stream"), + headers=headers, + json=payload, + stream=True, + timeout=60, + ) + for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): + if chunk: + data = json.loads(chunk.decode()) + if data["error_code"] == 0: + output = data["text"][skip_echo_len].strip() + output = post_process_code(output) + state.messages[-1][-1] = output + "▌" + yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5 + else: + output = data["text"] + f" (error_code): {data['error_code']}" + state.messages[-1][-1] = output + yield (state, state.to_gradio_chatbot()) + ( + disable_btn, + disable_btn, + disable_btn, + enable_btn, + enable_btn + ) + + return + time.sleep(0.02) + except requests.exceptions.RequestException as e: + state.messages[-1][-1] = server_error_msg + f" (error_code: 4)" + yield (state, state.to_gradio_chatbot()) + ( + disable_btn, + disable_btn, + disable_btn, + enable_btn, + enable_btn + ) + return + + state.messages[-1][-1] = state.messages[-1][-1][:-1] + yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5 + + finish_tstamp = time.time() + logger.info(f"{output}") + + with open(get_conv_log_filename(), "a") as flog: + data = { + "tstamp": round(finish_tstamp, 4), + "type": "chat", + "model": model_name, + "gen_params": { + "temperature": temperature, + "max_new_tokens": max_new_tokens, + }, + "start": round(start_tstamp, 4), + "finish": round(finish_tstamp, 4), + "state": state.dict(), + "ip": request.client.host, + } + flog.write(json.dumps(data), + "\n") + +block_css = ( + code_highlight_css + + """ +pre { + white-space: pre-wrap; /* Since CSS 2.1 */ + white-space: -moz-pre-wrap; /* Mozilla, since 1999 */ + white-space: -pre-wrap; /* Opera 4-6 */ + white-space: -o-pre-wrap; /* Opera 7 */ + word-wrap: break-word; /* Internet Explorer 5.5+ */ +} +#notice_markdown th { + display: none; +} + """ +) 
+ + +def build_single_model_ui(): + + notice_markdown = """ + # DB-GPT + + [DB-GPT](https://github.com/csunny/DB-GPT) 是一个实验性的开源应用程序,它基于[FastChat](https://github.com/lm-sys/FastChat),并使用vicuna作为基础模型。此外,此程序结合了langchain和llama-index + ,基于现有知识库进行In-Context Learning来对其进行数据库相关知识的增强, 总的来说,它似乎是一个用于数据库的复杂且创新的AI工具。如果您对如何在工作中使用或实施 DB-GPT 有任何具体问题,请告诉我,我会尽力帮助您。 + """ + learn_more_markdown = """ + ### Licence + The service is a research preview intended for non-commercial use only. subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA + """ + + state = gr.State() + notice = gr.Markdown(notice_markdown, elem_id="notice_markdown") + + + chatbot = grChatbot(elem_id="chatbot", visible=False).style(height=550) + with gr.Row(): + with gr.Column(scale=20): + textbox = gr.Textbox( + show_label=False, + placeholder="Enter text and press ENTER", + visible=False, + ).style(container=False) + + with gr.Column(scale=2, min_width=50): + send_btn = gr.Button(value="" "发送", visible=False) + + with gr.Row(visible=False) as button_row: + regenerate_btn = gr.Button(value="🔄" "重新生成", interactive=False) + clear_btn = gr.Button(value="🗑️" "清理", interactive=False) + + with gr.Accordion("参数", open=False, visible=False) as parameter_row: + temperature = gr.Slider( + minimum=0.0, + maximum=1.0, + value=0.7, + step=0.1, + interactive=True, + label="Temperature", + ) + + max_output_tokens = gr.Slider( + minimum=0, + maximum=1024, + value=512, + step=64, + interactive=True, + label="最大输出Token数", + ) + + gr.Markdown(learn_more_markdown) + + + btn_list = [regenerate_btn, clear_btn] + regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then( + http_bot, + [state, temperature, max_output_tokens], + [state, chatbot] + btn_list, + ) + clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list) + + textbox.submit( + add_text, [state, textbox], [state, chatbot, textbox] + btn_list + ).then( + http_bot, + [state, 
temperature, max_output_tokens], + [state, chatbot] + btn_list, + ) + + send_btn.click( + add_text, [state, textbox], [state, chatbot, textbox] + btn_list + ).then( + http_bot, + [state, temperature, max_output_tokens], + [state, chatbot] + btn_list + ) + + return state, chatbot, textbox, send_btn, button_row, parameter_row + + +def build_webdemo(): + with gr.Blocks( + title="数据库智能助手", + theme=gr.themes.Base(), + # theme=gr.themes.Monochrome(), + css=block_css, + ) as demo: + url_params = gr.JSON(visible=False) + ( + state, + chatbot, + textbox, + send_btn, + button_row, + parameter_row, + ) = build_single_model_ui() + + if args.model_list_mode == "once": + demo.load( + load_demo, + [url_params], + [ + state, + chatbot, + textbox, + send_btn, + button_row, + parameter_row, + ], + _js=get_window_url_params, + ) + else: + raise ValueError(f"Unknown model list mode: {args.model_list_mode}") + return demo + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="0.0.0.0") + parser.add_argument("--port", type=int) + parser.add_argument("--concurrency-count", type=int, default=10) + parser.add_argument( + "--model-list-mode", type=str, default="once", choices=["once", "reload"] + ) + parser.add_argument("--share", default=False, action="store_true") + parser.add_argument( + "--moderate", action="store_true", help="Enable content moderation" + ) + args = parser.parse_args() + logger.info(f"args: {args}") + + set_global_vars(args.moderate) + + logger.info(args) + demo = build_webdemo() + demo.queue( + concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False + ).launch( + server_name=args.host, server_port=args.port, share=args.share, max_threads=200, + ) \ No newline at end of file From 65ead15b656fdcf1f486dcd32ec2bb1a107132fa Mon Sep 17 00:00:00 2001 From: csunny Date: Sun, 30 Apr 2023 21:03:47 +0800 Subject: [PATCH 08/17] add darm mode --- pilot/server/webserver.py | 90 
++++++++++++++++++++------------------- 1 file changed, 46 insertions(+), 44 deletions(-) diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py index ae16ac9ba..db5eca86c 100644 --- a/pilot/server/webserver.py +++ b/pilot/server/webserver.py @@ -71,6 +71,10 @@ function() { const params = new URLSearchParams(window.location.search); url_params = Object.fromEntries(params); console.log(url_params); + gradioURL = window.location.href + if (!gradioURL.endsWith('?__theme=dark')) { + window.location.replace(gradioURL + '?__theme=dark'); + } return url_params; } """ @@ -248,8 +252,7 @@ def build_single_model_ui(): notice_markdown = """ # DB-GPT - [DB-GPT](https://github.com/csunny/DB-GPT) 是一个实验性的开源应用程序,它基于[FastChat](https://github.com/lm-sys/FastChat),并使用vicuna作为基础模型。此外,此程序结合了langchain和llama-index - ,基于现有知识库进行In-Context Learning来对其进行数据库相关知识的增强, 总的来说,它似乎是一个用于数据库的复杂且创新的AI工具。如果您对如何在工作中使用或实施 DB-GPT 有任何具体问题,请告诉我,我会尽力帮助您。 + [DB-GPT](https://github.com/csunny/DB-GPT) 是一个实验性的开源应用程序,它基于[FastChat](https://github.com/lm-sys/FastChat),并使用vicuna作为基础模型。此外,此程序结合了langchain和llama-index基于现有知识库进行In-Context Learning来对其进行数据库相关知识的增强, 总的来说,它似乎是一个用于数据库的复杂且创新的AI工具。如果您对如何在工作中使用或实施 DB-GPT 有任何具体问题,请告诉我,我会尽力帮助您。 """ learn_more_markdown = """ ### Licence @@ -259,23 +262,6 @@ def build_single_model_ui(): state = gr.State() notice = gr.Markdown(notice_markdown, elem_id="notice_markdown") - - chatbot = grChatbot(elem_id="chatbot", visible=False).style(height=550) - with gr.Row(): - with gr.Column(scale=20): - textbox = gr.Textbox( - show_label=False, - placeholder="Enter text and press ENTER", - visible=False, - ).style(container=False) - - with gr.Column(scale=2, min_width=50): - send_btn = gr.Button(value="" "发送", visible=False) - - with gr.Row(visible=False) as button_row: - regenerate_btn = gr.Button(value="🔄" "重新生成", interactive=False) - clear_btn = gr.Button(value="🗑️" "清理", interactive=False) - with gr.Accordion("参数", open=False, visible=False) as parameter_row: temperature = 
gr.Slider( minimum=0.0, @@ -295,41 +281,57 @@ def build_single_model_ui(): label="最大输出Token数", ) - gr.Markdown(learn_more_markdown) + chatbot = grChatbot(elem_id="chatbot", visible=False).style(height=550) + with gr.Row(): + with gr.Column(scale=20): + textbox = gr.Textbox( + show_label=False, + placeholder="Enter text and press ENTER", + visible=False, + ).style(container=False) + with gr.Column(scale=2, min_width=50): + send_btn = gr.Button(value="" "发送", visible=False) - btn_list = [regenerate_btn, clear_btn] - regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then( - http_bot, - [state, temperature, max_output_tokens], - [state, chatbot] + btn_list, - ) - clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list) - textbox.submit( - add_text, [state, textbox], [state, chatbot, textbox] + btn_list - ).then( - http_bot, - [state, temperature, max_output_tokens], - [state, chatbot] + btn_list, - ) + with gr.Row(visible=False) as button_row: + regenerate_btn = gr.Button(value="🔄" "重新生成", interactive=False) + clear_btn = gr.Button(value="🗑️" "清理", interactive=False) - send_btn.click( - add_text, [state, textbox], [state, chatbot, textbox] + btn_list - ).then( - http_bot, - [state, temperature, max_output_tokens], - [state, chatbot] + btn_list - ) + gr.Markdown(learn_more_markdown) - return state, chatbot, textbox, send_btn, button_row, parameter_row + btn_list = [regenerate_btn, clear_btn] + regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then( + http_bot, + [state, temperature, max_output_tokens], + [state, chatbot] + btn_list, + ) + clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list) + + textbox.submit( + add_text, [state, textbox], [state, chatbot, textbox] + btn_list + ).then( + http_bot, + [state, temperature, max_output_tokens], + [state, chatbot] + btn_list, + ) + + send_btn.click( + add_text, [state, textbox], [state, chatbot, textbox] + btn_list + ).then( + 
http_bot, + [state, temperature, max_output_tokens], + [state, chatbot] + btn_list + ) + + return state, chatbot, textbox, send_btn, button_row, parameter_row def build_webdemo(): with gr.Blocks( title="数据库智能助手", - theme=gr.themes.Base(), - # theme=gr.themes.Monochrome(), + # theme=gr.themes.Base(), + theme=gr.themes.Default(), css=block_css, ) as demo: url_params = gr.JSON(visible=False) From 905f14cf7c3d702f84e54f1d1e4112e230cc773e Mon Sep 17 00:00:00 2001 From: csunny Date: Sun, 30 Apr 2023 21:21:42 +0800 Subject: [PATCH 09/17] add debug log --- pilot/server/vicuna_server.py | 4 ++-- pilot/server/webserver.py | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/pilot/server/vicuna_server.py b/pilot/server/vicuna_server.py index 91aba5dec..568503d6c 100644 --- a/pilot/server/vicuna_server.py +++ b/pilot/server/vicuna_server.py @@ -48,8 +48,8 @@ def generate_stream_gate(params): params, DEVICE, MAX_POSITION_EMBEDDINGS, - 2, ): + print("output: ", output) ret = { "text": output, "error_code": 0, @@ -68,7 +68,7 @@ async def api_generate_stream(request: Request): global model_semaphore, global_counter global_counter += 1 params = await request.json() - + print(model, tokenizer, params, DEVICE) if model_semaphore is None: model_semaphore = asyncio.Semaphore(LIMIT_MODEL_CONCURRENCY) await model_semaphore.acquire() diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py index db5eca86c..7b37895a1 100644 --- a/pilot/server/webserver.py +++ b/pilot/server/webserver.py @@ -156,12 +156,13 @@ def http_bot(state, temperature, max_new_tokens, request: gr.Request): prompt = state.get_prompt() skip_echo_len = compute_skip_echo_len(prompt) + logger.info(f"State: {state}") payload = { "model": model_name, "prompt": prompt, "temperature": temperature, "max_new_tokens": max_new_tokens, - "stop": state.sep if state.sep_style == SeparatorStyle.SINGLE else None, + "stop": state.sep, } logger.info(f"Request: \n {payload}") @@ -179,6 +180,7 @@ def 
http_bot(state, temperature, max_new_tokens, request: gr.Request): for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): if chunk: data = json.loads(chunk.decode()) + logger.info(f"Response: {data}") if data["error_code"] == 0: output = data["text"][skip_echo_len].strip() output = post_process_code(output) From f0e17ed8f1205cbcb5678d3dd5118d89ca1ad1b3 Mon Sep 17 00:00:00 2001 From: csunny Date: Sun, 30 Apr 2023 21:27:55 +0800 Subject: [PATCH 10/17] handler request sync --- pilot/server/vicuna_server.py | 9 ++++----- pilot/server/webserver.py | 2 +- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/pilot/server/vicuna_server.py b/pilot/server/vicuna_server.py index 568503d6c..00dc87057 100644 --- a/pilot/server/vicuna_server.py +++ b/pilot/server/vicuna_server.py @@ -64,14 +64,13 @@ def generate_stream_gate(params): @app.post("/generate_stream") -async def api_generate_stream(request: Request): +def api_generate_stream(request: Request): global model_semaphore, global_counter global_counter += 1 - params = await request.json() + params = request.json() print(model, tokenizer, params, DEVICE) - if model_semaphore is None: - model_semaphore = asyncio.Semaphore(LIMIT_MODEL_CONCURRENCY) - await model_semaphore.acquire() + # if model_semaphore is None: + # model_semaphore = asyncio.Semaphore(LIMIT_MODEL_CONCURRENCY) generator = generate_stream_gate(params) background_tasks = BackgroundTasks() diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py index 7b37895a1..1a2ea0665 100644 --- a/pilot/server/webserver.py +++ b/pilot/server/webserver.py @@ -254,7 +254,7 @@ def build_single_model_ui(): notice_markdown = """ # DB-GPT - [DB-GPT](https://github.com/csunny/DB-GPT) 是一个实验性的开源应用程序,它基于[FastChat](https://github.com/lm-sys/FastChat),并使用vicuna作为基础模型。此外,此程序结合了langchain和llama-index基于现有知识库进行In-Context Learning来对其进行数据库相关知识的增强, 总的来说,它似乎是一个用于数据库的复杂且创新的AI工具。如果您对如何在工作中使用或实施 DB-GPT 有任何具体问题,请告诉我,我会尽力帮助您。 + 
[DB-GPT](https://github.com/csunny/DB-GPT) 是一个实验性的开源应用程序,它基于[FastChat](https://github.com/lm-sys/FastChat),并使用vicuna作为基础模型。此外,此程序结合了langchain和llama-index基于现有知识库进行In-Context Learning来对其进行数据库相关知识的增强, 总的来说,它是一个用于数据库的复杂且创新的AI工具。如果您对如何在工作中使用或实施DB-GPT有任何具体问题,请联系我, 我会尽力提供帮助, 同时也欢迎大家参与到项目建设中, 做一些有趣的事情。 """ learn_more_markdown = """ ### Licence From eef244fe92151356fdc3464eda95e0bb97d54ccc Mon Sep 17 00:00:00 2001 From: csunny Date: Sun, 30 Apr 2023 21:54:49 +0800 Subject: [PATCH 11/17] rebuild params --- pilot/server/vicuna_server.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/pilot/server/vicuna_server.py b/pilot/server/vicuna_server.py index 00dc87057..61a651fe7 100644 --- a/pilot/server/vicuna_server.py +++ b/pilot/server/vicuna_server.py @@ -33,6 +33,13 @@ class PromptRequest(BaseModel): stop: Optional[List[str]] = None +class StreamRequest(BaseModel): + model: str + prompt: str + temperature: float + max_new_tokens: int + stop: str + class EmbeddingRequest(BaseModel): prompt: str @@ -64,10 +71,16 @@ def generate_stream_gate(params): @app.post("/generate_stream") -def api_generate_stream(request: Request): +def api_generate_stream(request: StreamRequest): global model_semaphore, global_counter global_counter += 1 - params = request.json() + params = { + "prompt": request.prompt, + "model": request.model, + "temperature": request.temperature, + "max_new_tokens": request.max_new_tokens, + "stop": request.stop + } print(model, tokenizer, params, DEVICE) # if model_semaphore is None: # model_semaphore = asyncio.Semaphore(LIMIT_MODEL_CONCURRENCY) From 7af5433ae34034b58a897825433a52b8541eb5c5 Mon Sep 17 00:00:00 2001 From: csunny Date: Sun, 30 Apr 2023 22:07:03 +0800 Subject: [PATCH 12/17] test and fix --- pilot/server/vicuna_server.py | 6 +++--- pilot/server/webserver.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pilot/server/vicuna_server.py b/pilot/server/vicuna_server.py index 61a651fe7..6c18f9cc6 100644 
--- a/pilot/server/vicuna_server.py +++ b/pilot/server/vicuna_server.py @@ -86,9 +86,9 @@ def api_generate_stream(request: StreamRequest): # model_semaphore = asyncio.Semaphore(LIMIT_MODEL_CONCURRENCY) generator = generate_stream_gate(params) - background_tasks = BackgroundTasks() - background_tasks.add_task(release_model_semaphore) - return StreamingResponse(generator, background=background_tasks) + # background_tasks = BackgroundTasks() + # background_tasks.add_task(release_model_semaphore) + return StreamingResponse(generator) @app.post("/generate") def generate(prompt_request: PromptRequest): diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py index 1a2ea0665..00f71ae3c 100644 --- a/pilot/server/webserver.py +++ b/pilot/server/webserver.py @@ -254,7 +254,7 @@ def build_single_model_ui(): notice_markdown = """ # DB-GPT - [DB-GPT](https://github.com/csunny/DB-GPT) 是一个实验性的开源应用程序,它基于[FastChat](https://github.com/lm-sys/FastChat),并使用vicuna作为基础模型。此外,此程序结合了langchain和llama-index基于现有知识库进行In-Context Learning来对其进行数据库相关知识的增强, 总的来说,它是一个用于数据库的复杂且创新的AI工具。如果您对如何在工作中使用或实施DB-GPT有任何具体问题,请联系我, 我会尽力提供帮助, 同时也欢迎大家参与到项目建设中, 做一些有趣的事情。 + [DB-GPT](https://github.com/csunny/DB-GPT) 是一个实验性的开源应用程序,它基于[FastChat](https://github.com/lm-sys/FastChat),并使用vicuna-13b作为基础模型。此外,此程序结合了langchain和llama-index基于现有知识库进行In-Context Learning来对其进行数据库相关知识的增强。它可以进行SQL生成、SQL诊断、数据库知识问答等一系列的工作。 总的来说,它是一个用于数据库的复杂且创新的AI工具。如果您对如何在工作中使用或实施DB-GPT有任何具体问题,请联系我, 我会尽力提供帮助, 同时也欢迎大家参与到项目建设中, 做一些有趣的事情。 """ learn_more_markdown = """ ### Licence From 4243211220df53d4ac2e77867e63566c600a548f Mon Sep 17 00:00:00 2001 From: csunny Date: Sun, 30 Apr 2023 22:24:06 +0800 Subject: [PATCH 13/17] test --- pilot/conversation.py | 8 ++++---- pilot/server/vicuna_server.py | 31 ++++++++++++++++--------------- 2 files changed, 20 insertions(+), 19 deletions(-) diff --git a/pilot/conversation.py b/pilot/conversation.py index 4db5d9548..1ee142762 100644 --- a/pilot/conversation.py +++ b/pilot/conversation.py @@ -89,8 +89,8 @@ 
class Conversation: conv_one_shot = Conversation( - system="A chat between a curious human and an artificial intelligence assistant, who very familiar with database related knowledge. " - "The assistant gives helpful, detailed, professional and polite answers to the human's questions. ", + system="A chat between a curious human and an artificial intelligence assistant." + "The assistant gives helpful, detailed and polite answers to the human's questions. ", roles=("Human", "Assistant"), messages=( ( @@ -123,8 +123,8 @@ conv_one_shot = Conversation( ) conv_vicuna_v1 = Conversation( - system = "A chat between a curious user and an artificial intelligence assistant, who very familiar with database related knowledge. " - "The assistant gives helpful, detailed, professional and polite answers to the user's questions. ", + system = "A chat between a curious user and an artificial intelligence assistant. " + "The assistant gives helpful, detailed and polite answers to the user's questions. ", roles=("USER", "ASSISTANT"), messages=(), offset=0, diff --git a/pilot/server/vicuna_server.py b/pilot/server/vicuna_server.py index 6c18f9cc6..3012d7ef1 100644 --- a/pilot/server/vicuna_server.py +++ b/pilot/server/vicuna_server.py @@ -22,8 +22,14 @@ model_semaphore = None # ml = ModerLoader(model_path=model_path) # model, tokenizer = ml.loader(load_8bit=isload_8bit, debug=isdebug) - model, tokenizer = load_model(model_path=model_path, device=DEVICE, num_gpus=1, load_8bit=True, debug=False) + +class ModelWorker: + def __init__(self): + pass + + # TODO + app = FastAPI() class PromptRequest(BaseModel): @@ -71,24 +77,19 @@ def generate_stream_gate(params): @app.post("/generate_stream") -def api_generate_stream(request: StreamRequest): +async def api_generate_stream(request: Request): global model_semaphore, global_counter global_counter += 1 - params = { - "prompt": request.prompt, - "model": request.model, - "temperature": request.temperature, - "max_new_tokens": 
request.max_new_tokens, - "stop": request.stop - } + params = await request.json() print(model, tokenizer, params, DEVICE) - # if model_semaphore is None: - # model_semaphore = asyncio.Semaphore(LIMIT_MODEL_CONCURRENCY) - + if model_semaphore is None: + model_semaphore = asyncio.Semaphore(LIMIT_MODEL_CONCURRENCY) + await model_semaphore.acquire() + generator = generate_stream_gate(params) - # background_tasks = BackgroundTasks() - # background_tasks.add_task(release_model_semaphore) - return StreamingResponse(generator) + background_tasks = BackgroundTasks() + background_tasks.add_task(release_model_semaphore) + return StreamingResponse(generator, background=background_tasks) @app.post("/generate") def generate(prompt_request: PromptRequest): From c34e7224122c3cd74af09b01c2cfe5767a8b7d66 Mon Sep 17 00:00:00 2001 From: csunny Date: Sun, 30 Apr 2023 23:16:10 +0800 Subject: [PATCH 14/17] version align --- pilot/configs/model_config.py | 2 +- pilot/server/webserver.py | 11 +++++++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py index cf26ff7e1..249979477 100644 --- a/pilot/configs/model_config.py +++ b/pilot/configs/model_config.py @@ -19,7 +19,7 @@ llm_model_config = { LLM_MODEL = "vicuna-13b" LIMIT_MODEL_CONCURRENCY = 5 MAX_POSITION_EMBEDDINGS = 2048 -vicuna_model_server = "http://192.168.31.114:8000/" +vicuna_model_server = "http://192.168.31.114:21002" # Load model config diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py index 00f71ae3c..8d4cb52ad 100644 --- a/pilot/server/webserver.py +++ b/pilot/server/webserver.py @@ -109,13 +109,15 @@ def add_text(state, text, request: gr.Request): if len(text) <= 0: state.skip_next = True return (state, state.to_gradio_chatbot(), "") + (no_change_btn,) * 5 - + + # TODO 根据ChatGPT来进行模型评估 if enable_moderation: flagged = violates_moderation(text) if flagged: logger.info(f"violate moderation. ip: {request.client.host}. 
text: {text}") state.skip_next = True return (state, state.to_gradio_chatbot(), moderation_msg) + (no_change_btn,) * 5 + text = text[:1536] # ? state.append_message(state.roles[0], text) state.append_message(state.roles[1], None) @@ -146,6 +148,7 @@ def http_bot(state, temperature, max_new_tokens, request: gr.Request): return if len(state.messages) == state.offset + 2: + # 第一轮对话 new_state = get_default_conv_template(model_name).copy() new_state.conv_id = uuid.uuid4().hex new_state.append_message(new_state.roles[0], state.messages[-2][1]) @@ -162,7 +165,7 @@ def http_bot(state, temperature, max_new_tokens, request: gr.Request): "prompt": prompt, "temperature": temperature, "max_new_tokens": max_new_tokens, - "stop": state.sep, + "stop": state.sep if state.sep_style == SeparatorStyle.SINGLE else None, } logger.info(f"Request: \n {payload}") @@ -171,11 +174,11 @@ def http_bot(state, temperature, max_new_tokens, request: gr.Request): try: response = requests.post( - url=urljoin(vicuna_model_server, "generate_stream"), + url=urljoin(vicuna_model_server, "worker_generate_stream"), headers=headers, json=payload, stream=True, - timeout=60, + timeout=20, ) for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): if chunk: From 7861fc28cefa61a5f06187b253302574223c29b0 Mon Sep 17 00:00:00 2001 From: csunny Date: Sun, 30 Apr 2023 23:58:32 +0800 Subject: [PATCH 15/17] fix problem --- environment.yml | 1 + pilot/conversation.py | 26 ++----- pilot/server/vicuna_server.py | 1 + pilot/server/webserver.py | 137 ++++++++++++---------------------- requirements.txt | 3 +- 5 files changed, 60 insertions(+), 108 deletions(-) diff --git a/environment.yml b/environment.yml index 3ec4dfd98..3ba070e56 100644 --- a/environment.yml +++ b/environment.yml @@ -60,3 +60,4 @@ dependencies: - gradio==3.24.1 - gradio-client==0.0.8 - wandb + - fschat=0.1.10 diff --git a/pilot/conversation.py b/pilot/conversation.py index 1ee142762..3a702c615 100644 --- a/pilot/conversation.py +++ 
b/pilot/conversation.py @@ -3,7 +3,7 @@ import dataclasses from enum import auto, Enum -from typing import List, Tuple, Any +from typing import List, Any class SeparatorStyle(Enum): @@ -29,12 +29,12 @@ class Conversation: def get_prompt(self): if self.sep_style == SeparatorStyle.SINGLE: - ret = self.system + ret = self.system + self.sep for role, message in self.messages: if message: - ret += self.sep + " " + role + ": " + message + ret += role + ": " + message + self.sep else: - ret += self.sep + " " + role + ":" + ret += role + ":" return ret elif self.sep_style == SeparatorStyle.TWO: @@ -56,7 +56,7 @@ class Conversation: def to_gradio_chatbot(self): ret = [] - for i, (role, msg) in enumerate(self.messages[self.offset :]): + for i, (role, msg) in enumerate(self.messages[self.offset:]): if i % 2 == 0: ret.append([msg, None]) else: @@ -133,19 +133,9 @@ conv_vicuna_v1 = Conversation( sep2="", ) -conv_template = { +default_conversation = conv_one_shot + +conv_templates = { "conv_one_shot": conv_one_shot, "vicuna_v1": conv_vicuna_v1 } - - -def get_default_conv_template(model_name: str = "vicuna-13b"): - model_name = model_name.lower() - if "vicuna" in model_name: - return conv_vicuna_v1 - return conv_one_shot - - -def compute_skip_echo_len(prompt): - skip_echo_len = len(prompt) + 1 - prompt.count("") * 3 - return skip_echo_len \ No newline at end of file diff --git a/pilot/server/vicuna_server.py b/pilot/server/vicuna_server.py index 3012d7ef1..f664410e8 100644 --- a/pilot/server/vicuna_server.py +++ b/pilot/server/vicuna_server.py @@ -82,6 +82,7 @@ async def api_generate_stream(request: Request): global_counter += 1 params = await request.json() print(model, tokenizer, params, DEVICE) + if model_semaphore is None: model_semaphore = asyncio.Semaphore(LIMIT_MODEL_CONCURRENCY) await model_semaphore.acquire() diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py index 8d4cb52ad..8bd977cbb 100644 --- a/pilot/server/webserver.py +++ b/pilot/server/webserver.py 
@@ -14,8 +14,8 @@ from urllib.parse import urljoin from pilot.configs.model_config import LOGDIR, vicuna_model_server, LLM_MODEL from pilot.conversation import ( - get_default_conv_template, - compute_skip_echo_len, + default_conversation, + conv_templates, SeparatorStyle ) @@ -43,29 +43,6 @@ priority = { "vicuna-13b": "aaa" } -def set_global_vars(enable_moderation_): - global enable_moderation, models - enable_moderation = enable_moderation_ - -def load_demo_single(url_params): - dropdown_update = gr.Dropdown.update(visible=True) - if "model" in url_params: - model = url_params["model"] - if model in models: - dropdown_update = gr.Dropdown.update(value=model, visible=True) - - state = None - return ( - state, - dropdown_update, - gr.Chatbot.update(visible=True), - gr.Textbox.update(visible=True), - gr.Button.update(visible=True), - gr.Row.update(visible=True), - gr.Accordion.update(visible=True), - ) - - get_window_url_params = """ function() { const params = new URLSearchParams(window.location.search); @@ -78,10 +55,24 @@ function() { return url_params; } """ - def load_demo(url_params, request: gr.Request): logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}") - return load_demo_single(url_params) + + dropdown_update = gr.Dropdown.update(visible=True) + if "model" in url_params: + model = url_params["model"] + if model in models: + dropdown_update = gr.Dropdown.update( + value=model, visible=True) + + state = default_conversation.copy() + return (state, + dropdown_update, + gr.Chatbot.update(visible=True), + gr.Textbox.update(visible=True), + gr.Button.update(visible=True), + gr.Row.update(visible=True), + gr.Accordion.update(visible=True)) def get_conv_log_filename(): t = datetime.datetime.now() @@ -100,29 +91,23 @@ def clear_history(request: gr.Request): state = None return (state, [], "") + (disable_btn,) * 5 -def add_text(state, text, request: gr.Request): - logger.info(f"add_text. ip: {request.client.host}. 
len:{len(text)}") - if state is None: - state = get_default_conv_template("vicuna").copy() - +def add_text(state, text, request: gr.Request): + logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}") if len(text) <= 0: state.skip_next = True return (state, state.to_gradio_chatbot(), "") + (no_change_btn,) * 5 - - # TODO 根据ChatGPT来进行模型评估 - if enable_moderation: + if args.moderate: flagged = violates_moderation(text) if flagged: - logger.info(f"violate moderation. ip: {request.client.host}. text: {text}") state.skip_next = True - return (state, state.to_gradio_chatbot(), moderation_msg) + (no_change_btn,) * 5 + return (state, state.to_gradio_chatbot(), moderation_msg) + ( + no_change_btn,) * 5 - text = text[:1536] # ? - state.append_message(state.roles[0], text) + text = text[:1536] # Hard cut-off + state.append_message(state.roles[0], text) state.append_message(state.roles[1], None) state.skip_next = False - return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5 def post_process_code(code): @@ -136,104 +121,80 @@ def post_process_code(code): return code def http_bot(state, temperature, max_new_tokens, request: gr.Request): - logger.info(f"http_bot. 
ip: {request.client.host}")
     start_tstamp = time.time()
-
-    model_name = LLM_MODEL
-    temperature = float(temperature)
-    max_new_tokens = int(max_new_tokens)
+    model_name = LLM_MODEL
 
     if state.skip_next:
+        # This generate call is skipped due to invalid inputs
         yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
         return
 
     if len(state.messages) == state.offset + 2:
-        # 第一轮对话
-        new_state = get_default_conv_template(model_name).copy()
+        # First round of conversation
+
+        template_name = "conv_one_shot"
+        new_state = conv_templates[template_name].copy()
         new_state.conv_id = uuid.uuid4().hex
         new_state.append_message(new_state.roles[0], state.messages[-2][1])
         new_state.append_message(new_state.roles[1], None)
         state = new_state
 
     prompt = state.get_prompt()
-    skip_echo_len = compute_skip_echo_len(prompt)
+    skip_echo_len = len(prompt.replace("", " ")) + 1
 
-    logger.info(f"State: {state}")
+    # Make requests
     payload = {
         "model": model_name,
         "prompt": prompt,
-        "temperature": temperature,
-        "max_new_tokens": max_new_tokens,
-        "stop": state.sep if state.sep_style == SeparatorStyle.SINGLE else None,
+        "temperature": float(temperature),
+        "max_new_tokens": int(max_new_tokens),
+        "stop": state.sep if state.sep_style == SeparatorStyle.SINGLE else state.sep2,
     }
+    logger.info(f"Request: \n{payload}")
 
-    logger.info(f"Request: \n {payload}")
     state.messages[-1][-1] = "▌"
     yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
 
     try:
-        response = requests.post(
-            url=urljoin(vicuna_model_server, "worker_generate_stream"),
-            headers=headers,
-            json=payload,
-            stream=True,
-            timeout=20,
-        )
+        # Stream output
+        response = requests.post(urljoin(vicuna_model_server, "worker_generate_stream"),
+                            headers=headers, json=payload, stream=True, timeout=20)
         for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
             if chunk:
                 data = json.loads(chunk.decode())
-                logger.info(f"Response: {data}")
                 if data["error_code"] == 0:
-                    output = data["text"][skip_echo_len].strip()
+                    output = 
data["text"][skip_echo_len:].strip() output = post_process_code(output) state.messages[-1][-1] = output + "▌" yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5 else: - output = data["text"] + f" (error_code): {data['error_code']}" + output = data["text"] + f" (error_code: {data['error_code']})" state.messages[-1][-1] = output - yield (state, state.to_gradio_chatbot()) + ( - disable_btn, - disable_btn, - disable_btn, - enable_btn, - enable_btn - ) - + yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn) return time.sleep(0.02) except requests.exceptions.RequestException as e: state.messages[-1][-1] = server_error_msg + f" (error_code: 4)" - yield (state, state.to_gradio_chatbot()) + ( - disable_btn, - disable_btn, - disable_btn, - enable_btn, - enable_btn - ) + yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn) return state.messages[-1][-1] = state.messages[-1][-1][:-1] yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5 - + finish_tstamp = time.time() logger.info(f"{output}") - with open(get_conv_log_filename(), "a") as flog: + with open(get_conv_log_filename(), "a") as fout: data = { "tstamp": round(finish_tstamp, 4), "type": "chat", "model": model_name, - "gen_params": { - "temperature": temperature, - "max_new_tokens": max_new_tokens, - }, "start": round(start_tstamp, 4), - "finish": round(finish_tstamp, 4), + "finish": round(start_tstamp, 4), "state": state.dict(), "ip": request.client.host, } - flog.write(json.dumps(data), + "\n") + fout.write(json.dumps(data) + "\n") block_css = ( code_highlight_css @@ -382,8 +343,6 @@ if __name__ == "__main__": args = parser.parse_args() logger.info(f"args: {args}") - set_global_vars(args.moderate) - logger.info(args) demo = build_webdemo() demo.queue( diff --git a/requirements.txt b/requirements.txt index dd7bf5189..50354b3b4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -49,4 +49,5 @@ 
umap-learn
 notebook
 gradio==3.24.1
 gradio-client==0.0.8
-wandb
\ No newline at end of file
+wandb
+fschat==0.1.10
\ No newline at end of file

From ed9b62d9d3681871262a7dd52c5ab9b5f6f3e8f5 Mon Sep 17 00:00:00 2001
From: csunny <cfqcsunny@gmail.com>
Date: Sun, 30 Apr 2023 23:59:11 +0800
Subject: [PATCH 16/17] fix problem

---
 pilot/configs/model_config.py | 2 +-
 pilot/server/webserver.py     | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py
index 249979477..80b1dabe4 100644
--- a/pilot/configs/model_config.py
+++ b/pilot/configs/model_config.py
@@ -19,7 +19,7 @@ llm_model_config = {
 LLM_MODEL = "vicuna-13b"
 LIMIT_MODEL_CONCURRENCY = 5
 MAX_POSITION_EMBEDDINGS = 2048
-vicuna_model_server = "http://192.168.31.114:21002"
+vicuna_model_server = "http://192.168.31.114:8000"
 
 
 # Load model config
diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py
index 8bd977cbb..19ee4b697 100644
--- a/pilot/server/webserver.py
+++ b/pilot/server/webserver.py
@@ -157,7 +157,7 @@ def http_bot(state, temperature, max_new_tokens, request: gr.Request):
 
     try:
         # Stream output
-        response = requests.post(urljoin(vicuna_model_server, "worker_generate_stream"),
+        response = requests.post(urljoin(vicuna_model_server, "generate_stream"),
                             headers=headers, json=payload, stream=True, timeout=20)
         for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
             if chunk:

From 943cffa3a9cb09ed10cdca8afde49d0a1a2b2c92 Mon Sep 17 00:00:00 2001
From: csunny <cfqcsunny@gmail.com>
Date: Mon, 1 May 2023 00:12:03 +0800
Subject: [PATCH 17/17] add prompt

---
 pilot/conversation.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/pilot/conversation.py b/pilot/conversation.py
index 3a702c615..8e172e7dd 100644
--- a/pilot/conversation.py
+++ b/pilot/conversation.py
@@ -89,8 +89,8 @@ class Conversation:
 
 
 conv_one_shot = Conversation(
-    system="A chat between a curious human and an artificial intelligence assistant."
- "The assistant gives helpful, detailed and polite answers to the human's questions. ", + system="A chat between a curious human and an artificial intelligence assistant, who very familiar with database related knowledge. " + "The assistant gives helpful, detailed, professional and polite answers to the human's questions. ", roles=("Human", "Assistant"), messages=( ( @@ -123,8 +123,8 @@ conv_one_shot = Conversation( ) conv_vicuna_v1 = Conversation( - system = "A chat between a curious user and an artificial intelligence assistant. " - "The assistant gives helpful, detailed and polite answers to the user's questions. ", + system = "A chat between a curious user and an artificial intelligence assistant. who very familiar with database related knowledge. " + "The assistant gives helpful, detailed, professional and polite answers to the user's questions. ", roles=("USER", "ASSISTANT"), messages=(), offset=0,