Mirror of https://github.com/csunny/DB-GPT.git
Commit 66771eedbd
@@ -11,7 +11,7 @@ Overall, it appears to be a sophisticated and innovative tool for working with d
 1. Run model server
 ```
 cd pilot/server
-uvicorn vicuna_server:app --host 0.0.0.0
+python vicuna_server.py
 ```

 2. Run gradio webui
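The model server entry point changes from the uvicorn CLI to running the script directly; this works because a later hunk in this diff adds an `if __name__ == "__main__": uvicorn.run(...)` block to the server module. A minimal smoke test against the running server might look like the sketch below; the host and port come from `vicuna_model_server` in the config, and the request field names are assumptions based on the `PromptRequest` model shown further down.

```python
# Hypothetical smoke test for the model server started with `python vicuna_server.py`.
# Field names ("prompt", "stop") are assumed from the request models in this diff.
import requests

resp = requests.post(
    "http://192.168.31.114:8000/generate",
    json={"prompt": "Hello", "stop": None},
    timeout=60,
)
print(resp.status_code, resp.text)
```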
@@ -60,3 +60,4 @@ dependencies:
 - gradio==3.24.1
 - gradio-client==0.0.8
 - wandb
+- fschat=0.1.10
@@ -7,6 +7,7 @@ import os
 root_path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 model_path = os.path.join(root_path, "models")
 vector_storepath = os.path.join(root_path, "vector_store")
+LOGDIR = os.path.join(root_path, "logs")


 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
@@ -16,9 +17,9 @@ llm_model_config = {
 }

 LLM_MODEL = "vicuna-13b"
+LIMIT_MODEL_CONCURRENCY = 5
+MAX_POSITION_EMBEDDINGS = 2048
-vicuna_model_server = "http://192.168.31.114:8000/"
+vicuna_model_server = "http://192.168.31.114:8000"


 # Load model config
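The model server address loses its trailing slash. The gradio web server added later in this diff joins this base with endpoint names via `urljoin`; a quick check (not part of the diff) shows both spellings resolve to the same URL for a bare host:port base, so the change is mainly cosmetic:

```python
# Illustration only: how the web UI builds the request URL from vicuna_model_server.
from urllib.parse import urljoin

print(urljoin("http://192.168.31.114:8000", "generate_stream"))
print(urljoin("http://192.168.31.114:8000/", "generate_stream"))
# both print: http://192.168.31.114:8000/generate_stream
```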
pilot/conversation.py (new file, 141 lines)
@@ -0,0 +1,141 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-

import dataclasses
from enum import auto, Enum
from typing import List, Any


class SeparatorStyle(Enum):

    SINGLE = auto()
    TWO = auto()


@dataclasses.dataclass
class Conversation:
    """This class keeps all conversation history. """

    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    sep2: str = None

    # Used for gradio server
    skip_next: bool = False
    conv_id: Any = None

    def get_prompt(self):
        if self.sep_style == SeparatorStyle.SINGLE:
            ret = self.system + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ":"
            return ret

        elif self.sep_style == SeparatorStyle.TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ":" + message + seps[i % 2]
                else:
                    ret += role + ":"
            return ret

        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

    def append_message(self, role, message):
        self.messages.append([role, message])

    def to_gradio_chatbot(self):
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                ret.append([msg, None])
            else:
                ret[-1][-1] = msg

        return ret

    def copy(self):
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            conv_id=self.conv_id,
        )

    def dict(self):
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
            "conv_id": self.conv_id
        }


conv_one_shot = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant, who very familiar with database related knowledge. "
           "The assistant gives helpful, detailed, professional and polite answers to the human's questions. ",
    roles=("Human", "Assistant"),
    messages=(
        (
            "Human",
            "What are the key differences between mysql and postgres?",
        ),
        (
            "Assistant",
            "MySQL and PostgreSQL are both popular open-source relational database management systems (RDBMS) "
            "that have many similarities but also some differences. Here are some key differences: \n"
            "1. Data Types: PostgreSQL has a more extensive set of data types, "
            "including support for array, hstore, JSON, and XML, whereas MySQL has a more limited set.\n"
            "2. ACID compliance: Both MySQL and PostgreSQL support ACID compliance (Atomicity, Consistency, Isolation, Durability), "
            "but PostgreSQL is generally considered to be more strict in enforcing it.\n"
            "3. Replication: MySQL has a built-in replication feature, which allows you to replicate data across multiple servers,"
            "whereas PostgreSQL has a similar feature, but it is not as mature as MySQL's.\n"
            "4. Performance: MySQL is generally considered to be faster and more efficient in handling large datasets, "
            "whereas PostgreSQL is known for its robustness and reliability.\n"
            "5. Licensing: MySQL is licensed under the GPL (General Public License), which means that it is free and open-source software, "
            "whereas PostgreSQL is licensed under the PostgreSQL License, which is also free and open-source but with different terms.\n"

            "Ultimately, the choice between MySQL and PostgreSQL depends on the specific needs and requirements of your application. "
            "Both are excellent database management systems, and choosing the right one "
            "for your project requires careful consideration of your application's requirements, performance needs, and scalability."
        ),
    ),
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###"
)


conv_vicuna_v1 = Conversation(
    system = "A chat between a curious user and an artificial intelligence assistant. who very familiar with database related knowledge. "
    "The assistant gives helpful, detailed, professional and polite answers to the user's questions. ",
    roles=("USER", "ASSISTANT"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)


default_conversation = conv_one_shot

conv_templates = {
    "conv_one_shot": conv_one_shot,
    "vicuna_v1": conv_vicuna_v1
}
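A brief usage sketch (not part of the diff) of how these templates serialize a dialogue; the question text is made up for illustration:

```python
# Illustration only: serializing one turn with the vicuna_v1 template.
from pilot.conversation import conv_templates

conv = conv_templates["vicuna_v1"].copy()
conv.append_message(conv.roles[0], "How do I list all tables in PostgreSQL?")
conv.append_message(conv.roles[1], None)

# SeparatorStyle.TWO alternates sep (" ") and sep2 ("</s>") between turns,
# leaving the prompt open after "ASSISTANT:" for the model to complete.
print(conv.get_prompt())
```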
@@ -9,11 +9,8 @@ def generate_output(model, tokenizer, params, device, context_len=2048):
     temperature = float(params.get("temperature", 1.0))
     max_new_tokens = int(params.get("max_new_tokens", 256))
     stop_parameter = params.get("stop", None)
-
-    print(tokenizer.__dir__())
-
     if stop_parameter == tokenizer.eos_token:
         stop_parameter = None

     stop_strings = []
     if isinstance(stop_parameter, str):
         stop_strings.append(stop_parameter)
@@ -43,13 +40,14 @@ def generate_output(model, tokenizer, params, device, context_len=2048):
                 past_key_values=past_key_values,
             )
         logits = out.logits
-        past_key_values = out.past_key_value
+        past_key_values = out.past_key_values

         last_token_logits = logits[0][-1]

         if temperature < 1e-4:
             token = int(torch.argmax(last_token_logits))
         else:
-            probs = torch.softmax(last_token_logits / temperature, dim=1)
+            probs = torch.softmax(last_token_logits / temperature, dim=-1)
             token = int(torch.multinomial(probs, num_samples=1))

         output_ids.append(token)
@@ -64,12 +62,12 @@ def generate_output(model, tokenizer, params, device, context_len=2048):
             pos = output.rfind(stop_str)
             if pos != -1:
                 output = output[:pos]
-                stoppped = True
+                stopped = True
                 break
             else:
                 pass

-        if stoppped:
+        if stopped:
             break

     del past_key_values
@@ -81,7 +79,9 @@ def generate_output(model, tokenizer, params, device, context_len=2048):
 @torch.inference_mode()
 def get_embeddings(model, tokenizer, prompt):
     input_ids = tokenizer(prompt).input_ids
-    input_embeddings = model.get_input_embeddings()
-    embeddings = input_embeddings(torch.LongTensor([input_ids]))
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    input_embeddings = model.get_input_embeddings().to(device)
+
+    embeddings = input_embeddings(torch.LongTensor([input_ids]).to(device))
     mean = torch.mean(embeddings[0], 0).cpu().detach()
-    return mean
+    return mean.to(device)
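Two of these fixes are worth spelling out: `out.past_key_values` matches the attribute name on the transformers causal-LM output object, and the `dim=1` to `dim=-1` change matters because `last_token_logits` is a one-dimensional vector of vocabulary logits. A quick check (not part of the diff, vocabulary size picked arbitrarily):

```python
# dim=1 raises "Dimension out of range" on a 1-D tensor; dim=-1 targets the
# last (and only) dimension, giving a proper probability distribution.
import torch

last_token_logits = torch.randn(32000)  # hypothetical vocabulary size
probs = torch.softmax(last_token_logits / 0.7, dim=-1)
token = int(torch.multinomial(probs, num_samples=1))
print(token, float(probs.sum()))  # a sampled token id, and 1.0
```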
pilot/server/embdserver.py (new file, 3 lines)
@@ -0,0 +1,3 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
@@ -1,16 +1,34 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-

+import uvicorn
+import asyncio
+import json
 from typing import Optional, List
-from fastapi import FastAPI
+from fastapi import FastAPI, Request, BackgroundTasks
+from fastapi.responses import StreamingResponse
+from fastchat.serve.inference import generate_stream
 from pydantic import BaseModel
 from pilot.model.inference import generate_output, get_embeddings
+from fastchat.serve.inference import load_model
 from pilot.model.loader import ModerLoader
 from pilot.configs.model_config import *

 model_path = llm_model_config[LLM_MODEL]
-ml = ModerLoader(model_path=model_path)
-model, tokenizer = ml.loader(load_8bit=isload_8bit, debug=isdebug)
+
+global_counter = 0
+model_semaphore = None
+
+# ml = ModerLoader(model_path=model_path)
+# model, tokenizer = ml.loader(load_8bit=isload_8bit, debug=isdebug)
+model, tokenizer = load_model(model_path=model_path, device=DEVICE, num_gpus=1, load_8bit=True, debug=False)
+
+
+class ModelWorker:
+    def __init__(self):
+        pass
+
+    # TODO
+
 app = FastAPI()
@@ -21,9 +39,58 @@ class PromptRequest(BaseModel):
     stop: Optional[List[str]] = None


+class StreamRequest(BaseModel):
+    model: str
+    prompt: str
+    temperature: float
+    max_new_tokens: int
+    stop: str
+
 class EmbeddingRequest(BaseModel):
     prompt: str

+
+def release_model_semaphore():
+    model_semaphore.release()
+
+
+def generate_stream_gate(params):
+    try:
+        for output in generate_stream(
+            model,
+            tokenizer,
+            params,
+            DEVICE,
+            MAX_POSITION_EMBEDDINGS,
+        ):
+            print("output: ", output)
+            ret = {
+                "text": output,
+                "error_code": 0,
+            }
+            yield json.dumps(ret).encode() + b"\0"
+    except torch.cuda.CudaError:
+        ret = {
+            "text": "**GPU OutOfMemory, Please Refresh.**",
+            "error_code": 0
+        }
+        yield json.dumps(ret).encode() + b"\0"
+
+
+@app.post("/generate_stream")
+async def api_generate_stream(request: Request):
+    global model_semaphore, global_counter
+    global_counter += 1
+    params = await request.json()
+    print(model, tokenizer, params, DEVICE)
+
+    if model_semaphore is None:
+        model_semaphore = asyncio.Semaphore(LIMIT_MODEL_CONCURRENCY)
+    await model_semaphore.acquire()
+
+    generator = generate_stream_gate(params)
+    background_tasks = BackgroundTasks()
+    background_tasks.add_task(release_model_semaphore)
+    return StreamingResponse(generator, background=background_tasks)
+
+
 @app.post("/generate")
 def generate(prompt_request: PromptRequest):
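The new `/generate_stream` endpoint yields JSON chunks terminated by null bytes, the same framing the gradio web server later in this diff consumes. A minimal standalone client sketch (not part of the diff; field names follow the `StreamRequest` model above) might look like:

```python
# Hypothetical client for the /generate_stream endpoint added above.
import json
import requests

payload = {
    "model": "vicuna-13b",
    "prompt": "USER: What is a covering index? ASSISTANT:",
    "temperature": 0.7,
    "max_new_tokens": 256,
    "stop": "</s>",
}
resp = requests.post("http://192.168.31.114:8000/generate_stream",
                     json=payload, stream=True, timeout=20)
for chunk in resp.iter_lines(decode_unicode=False, delimiter=b"\0"):
    if chunk:
        data = json.loads(chunk.decode())
        print(data["text"])  # progressively longer completions
```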
@@ -46,3 +113,8 @@ def embeddings(prompt_request: EmbeddingRequest):
     print("Received prompt: ", params["prompt"])
     output = get_embeddings(model, tokenizer, params["prompt"])
     return {"response": [float(x) for x in output]}
+
+
+
+if __name__ == "__main__":
+    uvicorn.run(app, host="0.0.0.0", log_level="info")
pilot/server/webserver.py (new file, 352 lines)
@@ -0,0 +1,352 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse
import os
import uuid
import json
import time
import gradio as gr
import datetime
import requests
from urllib.parse import urljoin

from pilot.configs.model_config import LOGDIR, vicuna_model_server, LLM_MODEL

from pilot.conversation import (
    default_conversation,
    conv_templates,
    SeparatorStyle
)

from fastchat.utils import (
    build_logger,
    server_error_msg,
    violates_moderation,
    moderation_msg
)

from fastchat.serve.gradio_patch import Chatbot as grChatbot
from fastchat.serve.gradio_css import code_highlight_css

logger = build_logger("webserver", "webserver.log")
headers = {"User-Agent": "dbgpt Client"}

no_change_btn = gr.Button.update()
enable_btn = gr.Button.update(interactive=True)
disable_btn = gr.Button.update(interactive=True)

enable_moderation = False
models = []

priority = {
    "vicuna-13b": "aaa"
}

get_window_url_params = """
function() {
    const params = new URLSearchParams(window.location.search);
    url_params = Object.fromEntries(params);
    console.log(url_params);
    gradioURL = window.location.href
    if (!gradioURL.endsWith('?__theme=dark')) {
        window.location.replace(gradioURL + '?__theme=dark');
    }
    return url_params;
}
"""

def load_demo(url_params, request: gr.Request):
    logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")

    dropdown_update = gr.Dropdown.update(visible=True)
    if "model" in url_params:
        model = url_params["model"]
        if model in models:
            dropdown_update = gr.Dropdown.update(
                value=model, visible=True)

    state = default_conversation.copy()
    return (state,
            dropdown_update,
            gr.Chatbot.update(visible=True),
            gr.Textbox.update(visible=True),
            gr.Button.update(visible=True),
            gr.Row.update(visible=True),
            gr.Accordion.update(visible=True))

def get_conv_log_filename():
    t = datetime.datetime.now()
    name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
    return name


def regenerate(state, request: gr.Request):
    logger.info(f"regenerate. ip: {request.client.host}")
    state.messages[-1][-1] = None
    state.skip_next = False
    return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5

def clear_history(request: gr.Request):
    logger.info(f"clear_history. ip: {request.client.host}")
    state = None
    return (state, [], "") + (disable_btn,) * 5


def add_text(state, text, request: gr.Request):
    logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
    if len(text) <= 0:
        state.skip_next = True
        return (state, state.to_gradio_chatbot(), "") + (no_change_btn,) * 5
    if args.moderate:
        flagged = violates_moderation(text)
        if flagged:
            state.skip_next = True
            return (state, state.to_gradio_chatbot(), moderation_msg) + (
                no_change_btn,) * 5

    text = text[:1536]  # Hard cut-off
    state.append_message(state.roles[0], text)
    state.append_message(state.roles[1], None)
    state.skip_next = False
    return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5

def post_process_code(code):
    sep = "\n```"
    if sep in code:
        blocks = code.split(sep)
        if len(blocks) % 2 == 1:
            for i in range(1, len(blocks), 2):
                blocks[i] = blocks[i].replace("\\_", "_")
        code = sep.join(blocks)
    return code

def http_bot(state, temperature, max_new_tokens, request: gr.Request):
    start_tstamp = time.time()
    model_name = LLM_MODEL

    if state.skip_next:
        # This generate call is skipped due to invalid inputs
        yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
        return

    if len(state.messages) == state.offset + 2:
        # First round of conversation

        template_name = "conv_one_shot"
        new_state = conv_templates[template_name].copy()
        new_state.conv_id = uuid.uuid4().hex
        new_state.append_message(new_state.roles[0], state.messages[-2][1])
        new_state.append_message(new_state.roles[1], None)
        state = new_state

    prompt = state.get_prompt()
    skip_echo_len = len(prompt.replace("</s>", " ")) + 1

    # Make requests
    payload = {
        "model": model_name,
        "prompt": prompt,
        "temperature": float(temperature),
        "max_new_tokens": int(max_new_tokens),
        "stop": state.sep if state.sep_style == SeparatorStyle.SINGLE else state.sep2,
    }
    logger.info(f"Requert: \n{payload}")

    state.messages[-1][-1] = "▌"
    yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5

    try:
        # Stream output
        response = requests.post(urljoin(vicuna_model_server, "generate_stream"),
                                 headers=headers, json=payload, stream=True, timeout=20)
        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
            if chunk:
                data = json.loads(chunk.decode())
                if data["error_code"] == 0:
                    output = data["text"][skip_echo_len:].strip()
                    output = post_process_code(output)
                    state.messages[-1][-1] = output + "▌"
                    yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
                else:
                    output = data["text"] + f" (error_code: {data['error_code']})"
                    state.messages[-1][-1] = output
                    yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
                    return
                time.sleep(0.02)
    except requests.exceptions.RequestException as e:
        state.messages[-1][-1] = server_error_msg + f" (error_code: 4)"
        yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
        return

    state.messages[-1][-1] = state.messages[-1][-1][:-1]
    yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5

    finish_tstamp = time.time()
    logger.info(f"{output}")

    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name,
            "start": round(start_tstamp, 4),
            "finish": round(start_tstamp, 4),
            "state": state.dict(),
            "ip": request.client.host,
        }
        fout.write(json.dumps(data) + "\n")

block_css = (
    code_highlight_css
    + """
pre {
    white-space: pre-wrap;       /* Since CSS 2.1 */
    white-space: -moz-pre-wrap;  /* Mozilla, since 1999 */
    white-space: -pre-wrap;      /* Opera 4-6 */
    white-space: -o-pre-wrap;    /* Opera 7 */
    word-wrap: break-word;       /* Internet Explorer 5.5+ */
}
#notice_markdown th {
    display: none;
}
"""
)


def build_single_model_ui():

    notice_markdown = """
# DB-GPT

[DB-GPT](https://github.com/csunny/DB-GPT) 是一个实验性的开源应用程序,它基于[FastChat](https://github.com/lm-sys/FastChat),并使用vicuna-13b作为基础模型。此外,此程序结合了langchain和llama-index基于现有知识库进行In-Context Learning来对其进行数据库相关知识的增强。它可以进行SQL生成、SQL诊断、数据库知识问答等一系列的工作。 总的来说,它是一个用于数据库的复杂且创新的AI工具。如果您对如何在工作中使用或实施DB-GPT有任何具体问题,请联系我, 我会尽力提供帮助, 同时也欢迎大家参与到项目建设中, 做一些有趣的事情。
"""
    learn_more_markdown = """
### Licence
The service is a research preview intended for non-commercial use only. subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA
"""

    state = gr.State()
    notice = gr.Markdown(notice_markdown, elem_id="notice_markdown")

    with gr.Accordion("参数", open=False, visible=False) as parameter_row:
        temperature = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.7,
            step=0.1,
            interactive=True,
            label="Temperature",
        )

        max_output_tokens = gr.Slider(
            minimum=0,
            maximum=1024,
            value=512,
            step=64,
            interactive=True,
            label="最大输出Token数",
        )

    chatbot = grChatbot(elem_id="chatbot", visible=False).style(height=550)
    with gr.Row():
        with gr.Column(scale=20):
            textbox = gr.Textbox(
                show_label=False,
                placeholder="Enter text and press ENTER",
                visible=False,
            ).style(container=False)

        with gr.Column(scale=2, min_width=50):
            send_btn = gr.Button(value="" "发送", visible=False)


    with gr.Row(visible=False) as button_row:
        regenerate_btn = gr.Button(value="🔄" "重新生成", interactive=False)
        clear_btn = gr.Button(value="🗑️" "清理", interactive=False)

    gr.Markdown(learn_more_markdown)

    btn_list = [regenerate_btn, clear_btn]
    regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
        http_bot,
        [state, temperature, max_output_tokens],
        [state, chatbot] + btn_list,
    )
    clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)

    textbox.submit(
        add_text, [state, textbox], [state, chatbot, textbox] + btn_list
    ).then(
        http_bot,
        [state, temperature, max_output_tokens],
        [state, chatbot] + btn_list,
    )

    send_btn.click(
        add_text, [state, textbox], [state, chatbot, textbox] + btn_list
    ).then(
        http_bot,
        [state, temperature, max_output_tokens],
        [state, chatbot] + btn_list
    )

    return state, chatbot, textbox, send_btn, button_row, parameter_row


def build_webdemo():
    with gr.Blocks(
        title="数据库智能助手",
        # theme=gr.themes.Base(),
        theme=gr.themes.Default(),
        css=block_css,
    ) as demo:
        url_params = gr.JSON(visible=False)
        (
            state,
            chatbot,
            textbox,
            send_btn,
            button_row,
            parameter_row,
        ) = build_single_model_ui()

        if args.model_list_mode == "once":
            demo.load(
                load_demo,
                [url_params],
                [
                    state,
                    chatbot,
                    textbox,
                    send_btn,
                    button_row,
                    parameter_row,
                ],
                _js=get_window_url_params,
            )
        else:
            raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
    return demo

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int)
    parser.add_argument("--concurrency-count", type=int, default=10)
    parser.add_argument(
        "--model-list-mode", type=str, default="once", choices=["once", "reload"]
    )
    parser.add_argument("--share", default=False, action="store_true")
    parser.add_argument(
        "--moderate", action="store_true", help="Enable content moderation"
    )
    args = parser.parse_args()
    logger.info(f"args: {args}")

    logger.info(args)
    demo = build_webdemo()
    demo.queue(
        concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False
    ).launch(
        server_name=args.host, server_port=args.port, share=args.share, max_threads=200,
    )
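The web UI is launched with the argparse options defined above (for example `python webserver.py --port 7860`; the port value is just an illustration). To round out the picture, the sketch below (not part of the diff) traces how `http_bot` derives its request parameters from the default `conv_one_shot` template:

```python
# Illustration only: the stop string and echo length http_bot computes
# before posting to the model server.
from pilot.conversation import conv_templates, SeparatorStyle

state = conv_templates["conv_one_shot"].copy()
state.append_message(state.roles[0], "Show me the slowest queries in MySQL")
state.append_message(state.roles[1], None)

prompt = state.get_prompt()
skip_echo_len = len(prompt.replace("</s>", " ")) + 1
stop = state.sep if state.sep_style == SeparatorStyle.SINGLE else state.sep2

print(stop)           # "###" for the single-separator one-shot template
print(skip_echo_len)  # prefix length stripped from each streamed chunk
```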
@@ -50,3 +50,4 @@ notebook
 gradio==3.24.1
 gradio-client==0.0.8
 wandb
+fschat=0.1.10