Merge pull request #3 from csunny/dev

First run version
magic.chen 2023-05-01 00:14:00 +08:00 committed by GitHub
commit 66771eedbd
9 changed files with 593 additions and 22 deletions


@@ -11,7 +11,7 @@ Overall, it appears to be a sophisticated and innovative tool for working with d
 1. Run model server
 ```
 cd pilot/server
-uvicorn vicuna_server:app --host 0.0.0.0
+python vicuna_server.py
 ```
 2. Run gradio webui


@@ -60,3 +60,4 @@ dependencies:
 - gradio==3.24.1
 - gradio-client==0.0.8
 - wandb
+- fschat=0.1.10


@@ -7,6 +7,7 @@ import os
 root_path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 model_path = os.path.join(root_path, "models")
 vector_storepath = os.path.join(root_path, "vector_store")
+LOGDIR = os.path.join(root_path, "logs")
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
@@ -16,9 +17,9 @@ llm_model_config = {
 }
 LLM_MODEL = "vicuna-13b"
+LIMIT_MODEL_CONCURRENCY = 5
+MAX_POSITION_EMBEDDINGS = 2048
-vicuna_model_server = "http://192.168.31.114:8000/"
+vicuna_model_server = "http://192.168.31.114:8000"
 # Load model config
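Note: the trailing slash is dropped because the webserver below builds the endpoint URL from this base with `urllib.parse.urljoin`. A minimal sketch of how the value is consumed (the address is just the example value from this config):

```python
from urllib.parse import urljoin

# Example base URL from model_config.py above; adjust to your own deployment.
vicuna_model_server = "http://192.168.31.114:8000"

# webserver.py joins the base URL with the streaming endpoint path.
print(urljoin(vicuna_model_server, "generate_stream"))
# -> http://192.168.31.114:8000/generate_stream
```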

pilot/conversation.py (new file)

@@ -0,0 +1,141 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-

import dataclasses
from enum import auto, Enum
from typing import List, Any


class SeparatorStyle(Enum):
    SINGLE = auto()
    TWO = auto()


@dataclasses.dataclass
class Conversation:
    """This class keeps all conversation history."""

    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    sep2: str = None

    # Used for gradio server
    skip_next: bool = False
    conv_id: Any = None

    def get_prompt(self):
        if self.sep_style == SeparatorStyle.SINGLE:
            ret = self.system + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ":" + message + seps[i % 2]
                else:
                    ret += role + ":"
            return ret
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

    def append_message(self, role, message):
        self.messages.append([role, message])

    def to_gradio_chatbot(self):
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def copy(self):
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            conv_id=self.conv_id,
        )

    def dict(self):
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
            "conv_id": self.conv_id,
        }


conv_one_shot = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant who is very familiar with database related knowledge. "
    "The assistant gives helpful, detailed, professional and polite answers to the human's questions. ",
    roles=("Human", "Assistant"),
    messages=(
        (
            "Human",
            "What are the key differences between mysql and postgres?",
        ),
        (
            "Assistant",
            "MySQL and PostgreSQL are both popular open-source relational database management systems (RDBMS) "
            "that have many similarities but also some differences. Here are some key differences: \n"
            "1. Data Types: PostgreSQL has a more extensive set of data types, "
            "including support for array, hstore, JSON, and XML, whereas MySQL has a more limited set.\n"
            "2. ACID compliance: Both MySQL and PostgreSQL support ACID compliance (Atomicity, Consistency, Isolation, Durability), "
            "but PostgreSQL is generally considered to be more strict in enforcing it.\n"
            "3. Replication: MySQL has a built-in replication feature, which allows you to replicate data across multiple servers, "
            "whereas PostgreSQL has a similar feature, but it is not as mature as MySQL's.\n"
            "4. Performance: MySQL is generally considered to be faster and more efficient in handling large datasets, "
            "whereas PostgreSQL is known for its robustness and reliability.\n"
            "5. Licensing: MySQL is licensed under the GPL (General Public License), which means that it is free and open-source software, "
            "whereas PostgreSQL is licensed under the PostgreSQL License, which is also free and open-source but with different terms.\n"
            "Ultimately, the choice between MySQL and PostgreSQL depends on the specific needs and requirements of your application. "
            "Both are excellent database management systems, and choosing the right one "
            "for your project requires careful consideration of your application's requirements, performance needs, and scalability."
        ),
    ),
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

conv_vicuna_v1 = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant who is very familiar with database related knowledge. "
    "The assistant gives helpful, detailed, professional and polite answers to the user's questions. ",
    roles=("USER", "ASSISTANT"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)

default_conversation = conv_one_shot

conv_templates = {
    "conv_one_shot": conv_one_shot,
    "vicuna_v1": conv_vicuna_v1,
}
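Note: a minimal sketch of how these templates are consumed. It mirrors what `http_bot()` in pilot/server/webserver.py does: copy a template, append the user turn plus an empty assistant slot, then render the prompt (the question text here is only an illustration):

```python
from pilot.conversation import conv_templates

conv = conv_templates["conv_one_shot"].copy()
conv.append_message(conv.roles[0], "Which index type should I use for a text search column?")
conv.append_message(conv.roles[1], None)  # empty slot for the model's reply

prompt = conv.get_prompt()  # system text + one-shot example + the new question
print(prompt[-200:])        # ends with "Assistant:" so the model can complete the turn
```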


@@ -9,11 +9,8 @@ def generate_output(model, tokenizer, params, device, context_len=2048):
     temperature = float(params.get("temperature", 1.0))
     max_new_tokens = int(params.get("max_new_tokens", 256))
     stop_parameter = params.get("stop", None)
-    print(tokenizer.__dir__())
     if stop_parameter == tokenizer.eos_token:
         stop_parameter = None
     stop_strings = []
     if isinstance(stop_parameter, str):
         stop_strings.append(stop_parameter)
@@ -43,13 +40,14 @@ def generate_output(model, tokenizer, params, device, context_len=2048):
                 past_key_values=past_key_values,
             )
             logits = out.logits
-            past_key_values = out.past_key_value
+            past_key_values = out.past_key_values
         last_token_logits = logits[0][-1]
         if temperature < 1e-4:
             token = int(torch.argmax(last_token_logits))
         else:
-            probs = torch.softmax(last_token_logits / temperature, dim=1)
+            probs = torch.softmax(last_token_logits / temperature, dim=-1)
             token = int(torch.multinomial(probs, num_samples=1))
         output_ids.append(token)
@@ -64,12 +62,12 @@ def generate_output(model, tokenizer, params, device, context_len=2048):
            pos = output.rfind(stop_str)
            if pos != -1:
                output = output[:pos]
-                stoppped = True
+                stopped = True
                break
            else:
                pass
        if stopped:
            break
    del past_key_values
@@ -81,7 +79,9 @@ def generate_output(model, tokenizer, params, device, context_len=2048):
 @torch.inference_mode()
 def get_embeddings(model, tokenizer, prompt):
     input_ids = tokenizer(prompt).input_ids
-    input_embeddings = model.get_input_embeddings()
-    embeddings = input_embeddings(torch.LongTensor([input_ids]))
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    input_embeddings = model.get_input_embeddings().to(device)
+    embeddings = input_embeddings(torch.LongTensor([input_ids]).to(device))
     mean = torch.mean(embeddings[0], 0).cpu().detach()
-    return mean
+    return mean.to(device)
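Note on the `dim` change above: `last_token_logits` is a 1-D tensor (`logits[0][-1]`, shape `(vocab_size,)`), so softmax over `dim=1` raises an IndexError, while `dim=-1` normalizes over the vocabulary. A small self-contained sketch (the logit values are made up):

```python
import torch

# last_token_logits in generate_output() is 1-D: shape (vocab_size,)
last_token_logits = torch.tensor([2.0, 0.5, -1.0])
temperature = 0.7

# dim=-1 normalizes over the only (vocabulary) dimension.
probs = torch.softmax(last_token_logits / temperature, dim=-1)
token = int(torch.multinomial(probs, num_samples=1))
print(probs, token)

# torch.softmax(last_token_logits / temperature, dim=1) would fail here:
# IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
```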


@@ -0,0 +1,3 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-


@@ -1,16 +1,34 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
+import uvicorn
+import asyncio
+import json
 from typing import Optional, List
-from fastapi import FastAPI
+from fastapi import FastAPI, Request, BackgroundTasks
+from fastapi.responses import StreamingResponse
+from fastchat.serve.inference import generate_stream
 from pydantic import BaseModel
 from pilot.model.inference import generate_output, get_embeddings
+from fastchat.serve.inference import load_model
 from pilot.model.loader import ModerLoader
 from pilot.configs.model_config import *
 model_path = llm_model_config[LLM_MODEL]
-ml = ModerLoader(model_path=model_path)
-model, tokenizer = ml.loader(load_8bit=isload_8bit, debug=isdebug)
+global_counter = 0
+model_semaphore = None
+# ml = ModerLoader(model_path=model_path)
+# model, tokenizer = ml.loader(load_8bit=isload_8bit, debug=isdebug)
+model, tokenizer = load_model(model_path=model_path, device=DEVICE, num_gpus=1, load_8bit=True, debug=False)
+class ModelWorker:
+    def __init__(self):
+        pass
+    # TODO
 app = FastAPI()
@@ -21,9 +39,58 @@ class PromptRequest(BaseModel):
     stop: Optional[List[str]] = None
+class StreamRequest(BaseModel):
+    model: str
+    prompt: str
+    temperature: float
+    max_new_tokens: int
+    stop: str
 class EmbeddingRequest(BaseModel):
     prompt: str
+def release_model_semaphore():
+    model_semaphore.release()
+def generate_stream_gate(params):
+    try:
+        for output in generate_stream(
+            model,
+            tokenizer,
+            params,
+            DEVICE,
+            MAX_POSITION_EMBEDDINGS,
+        ):
+            print("output: ", output)
+            ret = {
+                "text": output,
+                "error_code": 0,
+            }
+            yield json.dumps(ret).encode() + b"\0"
+    except torch.cuda.CudaError:
+        ret = {
+            "text": "**GPU OutOfMemory, Please Refresh.**",
+            "error_code": 0
+        }
+        yield json.dumps(ret).encode() + b"\0"
+@app.post("/generate_stream")
+async def api_generate_stream(request: Request):
+    global model_semaphore, global_counter
+    global_counter += 1
+    params = await request.json()
+    print(model, tokenizer, params, DEVICE)
+    if model_semaphore is None:
+        model_semaphore = asyncio.Semaphore(LIMIT_MODEL_CONCURRENCY)
+    await model_semaphore.acquire()
+    generator = generate_stream_gate(params)
+    background_tasks = BackgroundTasks()
+    background_tasks.add_task(release_model_semaphore)
+    return StreamingResponse(generator, background=background_tasks)
 @app.post("/generate")
 def generate(prompt_request: PromptRequest):
@@ -46,3 +113,8 @@ def embeddings(prompt_request: EmbeddingRequest):
     print("Received prompt: ", params["prompt"])
     output = get_embeddings(model, tokenizer, params["prompt"])
     return {"response": [float(x) for x in output]}
+if __name__ == "__main__":
+    uvicorn.run(app, host="0.0.0.0", log_level="info")
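Note: the gradio webserver below is the intended consumer of `/generate_stream`, but for reference here is a standalone client sketch. The request fields mirror `StreamRequest` above, the `b"\0"`-delimited JSON framing matches `generate_stream_gate()`, and the server address is the example value from model_config.py; the payload values are illustrative only:

```python
import json
import requests

payload = {
    "model": "vicuna-13b",
    "prompt": "A chat between a curious human and an assistant.### Human: hello### Assistant:",
    "temperature": 0.7,
    "max_new_tokens": 128,
    "stop": "###",
}

response = requests.post(
    "http://192.168.31.114:8000/generate_stream",
    json=payload,
    stream=True,
    timeout=20,
)
# Each chunk is a JSON object terminated by a null byte; "text" holds the
# cumulative generation so far (including the prompt echo).
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
    if chunk:
        data = json.loads(chunk.decode())
        if data["error_code"] == 0:
            print(data["text"])
```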

pilot/server/webserver.py (new file)

@@ -0,0 +1,352 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse
import os
import uuid
import json
import time
import gradio as gr
import datetime
import requests

from urllib.parse import urljoin

from pilot.configs.model_config import LOGDIR, vicuna_model_server, LLM_MODEL

from pilot.conversation import (
    default_conversation,
    conv_templates,
    SeparatorStyle
)

from fastchat.utils import (
    build_logger,
    server_error_msg,
    violates_moderation,
    moderation_msg
)

from fastchat.serve.gradio_patch import Chatbot as grChatbot
from fastchat.serve.gradio_css import code_highlight_css

logger = build_logger("webserver", "webserver.log")
headers = {"User-Agent": "dbgpt Client"}

no_change_btn = gr.Button.update()
enable_btn = gr.Button.update(interactive=True)
disable_btn = gr.Button.update(interactive=False)

enable_moderation = False
models = []

priority = {
    "vicuna-13b": "aaa"
}

get_window_url_params = """
function() {
    const params = new URLSearchParams(window.location.search);
    url_params = Object.fromEntries(params);
    console.log(url_params);
    gradioURL = window.location.href
    if (!gradioURL.endsWith('?__theme=dark')) {
        window.location.replace(gradioURL + '?__theme=dark');
    }
    return url_params;
}
"""


def load_demo(url_params, request: gr.Request):
    logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")

    dropdown_update = gr.Dropdown.update(visible=True)
    if "model" in url_params:
        model = url_params["model"]
        if model in models:
            dropdown_update = gr.Dropdown.update(
                value=model, visible=True)

    state = default_conversation.copy()
    return (state,
            dropdown_update,
            gr.Chatbot.update(visible=True),
            gr.Textbox.update(visible=True),
            gr.Button.update(visible=True),
            gr.Row.update(visible=True),
            gr.Accordion.update(visible=True))


def get_conv_log_filename():
    t = datetime.datetime.now()
    name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
    return name


def regenerate(state, request: gr.Request):
    logger.info(f"regenerate. ip: {request.client.host}")
    state.messages[-1][-1] = None
    state.skip_next = False
    return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5


def clear_history(request: gr.Request):
    logger.info(f"clear_history. ip: {request.client.host}")
    state = None
    return (state, [], "") + (disable_btn,) * 5


def add_text(state, text, request: gr.Request):
    logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
    if len(text) <= 0:
        state.skip_next = True
        return (state, state.to_gradio_chatbot(), "") + (no_change_btn,) * 5
    if args.moderate:
        flagged = violates_moderation(text)
        if flagged:
            state.skip_next = True
            return (state, state.to_gradio_chatbot(), moderation_msg) + (
                no_change_btn,) * 5

    text = text[:1536]  # Hard cut-off
    state.append_message(state.roles[0], text)
    state.append_message(state.roles[1], None)
    state.skip_next = False

    return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5


def post_process_code(code):
    sep = "\n```"
    if sep in code:
        blocks = code.split(sep)
        if len(blocks) % 2 == 1:
            for i in range(1, len(blocks), 2):
                blocks[i] = blocks[i].replace("\\_", "_")
        code = sep.join(blocks)
    return code


def http_bot(state, temperature, max_new_tokens, request: gr.Request):
    start_tstamp = time.time()
    model_name = LLM_MODEL

    if state.skip_next:
        # This generate call is skipped due to invalid inputs
        yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
        return

    if len(state.messages) == state.offset + 2:
        # First round of conversation
        template_name = "conv_one_shot"
        new_state = conv_templates[template_name].copy()
        new_state.conv_id = uuid.uuid4().hex
        new_state.append_message(new_state.roles[0], state.messages[-2][1])
        new_state.append_message(new_state.roles[1], None)
        state = new_state

    prompt = state.get_prompt()
    skip_echo_len = len(prompt.replace("</s>", " ")) + 1

    # Make requests
    payload = {
        "model": model_name,
        "prompt": prompt,
        "temperature": float(temperature),
        "max_new_tokens": int(max_new_tokens),
        "stop": state.sep if state.sep_style == SeparatorStyle.SINGLE else state.sep2,
    }
    logger.info(f"Request: \n{payload}")

    state.messages[-1][-1] = ""
    yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5

    try:
        # Stream output
        response = requests.post(urljoin(vicuna_model_server, "generate_stream"),
                                 headers=headers, json=payload, stream=True, timeout=20)
        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
            if chunk:
                data = json.loads(chunk.decode())
                if data["error_code"] == 0:
                    output = data["text"][skip_echo_len:].strip()
                    output = post_process_code(output)
                    state.messages[-1][-1] = output + "▌"
                    yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
                else:
                    output = data["text"] + f" (error_code: {data['error_code']})"
                    state.messages[-1][-1] = output
                    yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
                    return
                time.sleep(0.02)
    except requests.exceptions.RequestException as e:
        state.messages[-1][-1] = server_error_msg + f" (error_code: 4)"
        yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
        return

    state.messages[-1][-1] = state.messages[-1][-1][:-1]
    yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5

    finish_tstamp = time.time()
    logger.info(f"{output}")

    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name,
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),
            "state": state.dict(),
            "ip": request.client.host,
        }
        fout.write(json.dumps(data) + "\n")


block_css = (
    code_highlight_css
    + """
pre {
    white-space: pre-wrap;       /* Since CSS 2.1 */
    white-space: -moz-pre-wrap;  /* Mozilla, since 1999 */
    white-space: -pre-wrap;      /* Opera 4-6 */
    white-space: -o-pre-wrap;    /* Opera 7 */
    word-wrap: break-word;       /* Internet Explorer 5.5+ */
}
#notice_markdown th {
    display: none;
}
"""
)


def build_single_model_ui():
    notice_markdown = """
# DB-GPT

[DB-GPT](https://github.com/csunny/DB-GPT) is an experimental open-source application built on [FastChat](https://github.com/lm-sys/FastChat), using vicuna-13b as the base model. It also combines langchain and llama-index to run In-Context Learning over an existing knowledge base, enhancing the model with database-related knowledge, so it can handle tasks such as SQL generation, SQL diagnosis and database knowledge Q&A. Overall, it is a sophisticated and innovative AI tool for databases. If you have any specific questions about how to use or deploy DB-GPT in your work, please contact me and I will do my best to help; contributions to the project are also very welcome.
"""
    learn_more_markdown = """
### Licence
The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA.
"""

    state = gr.State()
    notice = gr.Markdown(notice_markdown, elem_id="notice_markdown")

    with gr.Accordion("Parameters", open=False, visible=False) as parameter_row:
        temperature = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.7,
            step=0.1,
            interactive=True,
            label="Temperature",
        )

        max_output_tokens = gr.Slider(
            minimum=0,
            maximum=1024,
            value=512,
            step=64,
            interactive=True,
            label="Max output tokens",
        )

    chatbot = grChatbot(elem_id="chatbot", visible=False).style(height=550)
    with gr.Row():
        with gr.Column(scale=20):
            textbox = gr.Textbox(
                show_label=False,
                placeholder="Enter text and press ENTER",
                visible=False,
            ).style(container=False)

        with gr.Column(scale=2, min_width=50):
            send_btn = gr.Button(value="Send", visible=False)

    with gr.Row(visible=False) as button_row:
        regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
        clear_btn = gr.Button(value="🗑️ Clear", interactive=False)

    gr.Markdown(learn_more_markdown)

    btn_list = [regenerate_btn, clear_btn]
    regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
        http_bot,
        [state, temperature, max_output_tokens],
        [state, chatbot] + btn_list,
    )
    clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)

    textbox.submit(
        add_text, [state, textbox], [state, chatbot, textbox] + btn_list
    ).then(
        http_bot,
        [state, temperature, max_output_tokens],
        [state, chatbot] + btn_list,
    )

    send_btn.click(
        add_text, [state, textbox], [state, chatbot, textbox] + btn_list
    ).then(
        http_bot,
        [state, temperature, max_output_tokens],
        [state, chatbot] + btn_list
    )

    return state, chatbot, textbox, send_btn, button_row, parameter_row


def build_webdemo():
    with gr.Blocks(
        title="Database Intelligent Assistant",
        # theme=gr.themes.Base(),
        theme=gr.themes.Default(),
        css=block_css,
    ) as demo:
        url_params = gr.JSON(visible=False)
        (
            state,
            chatbot,
            textbox,
            send_btn,
            button_row,
            parameter_row,
        ) = build_single_model_ui()

        if args.model_list_mode == "once":
            demo.load(
                load_demo,
                [url_params],
                [
                    state,
                    chatbot,
                    textbox,
                    send_btn,
                    button_row,
                    parameter_row,
                ],
                _js=get_window_url_params,
            )
        else:
            raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
    return demo


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int)
    parser.add_argument("--concurrency-count", type=int, default=10)
    parser.add_argument(
        "--model-list-mode", type=str, default="once", choices=["once", "reload"]
    )
    parser.add_argument("--share", default=False, action="store_true")
    parser.add_argument(
        "--moderate", action="store_true", help="Enable content moderation"
    )
    args = parser.parse_args()
    logger.info(f"args: {args}")

    demo = build_webdemo()
    demo.queue(
        concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False
    ).launch(
        server_name=args.host, server_port=args.port, share=args.share, max_threads=200,
    )


@@ -50,3 +50,4 @@ notebook
 gradio==3.24.1
 gradio-client==0.0.8
 wandb
+fschat==0.1.10