Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-07-31 07:34:07 +00:00)
Commit 66771eedbd
@@ -11,7 +11,7 @@ Overall, it appears to be a sophisticated and innovative tool for working with d
1. Run model server
```
cd pilot/server
uvicorn vicuna_server:app --host 0.0.0.0
python vicuna_server.py
```

2. Run gradio webui
@@ -60,3 +60,4 @@ dependencies:
- gradio==3.24.1
- gradio-client==0.0.8
- wandb
- fschat=0.1.10
@@ -7,6 +7,7 @@ import os
root_path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
model_path = os.path.join(root_path, "models")
vector_storepath = os.path.join(root_path, "vector_store")
LOGDIR = os.path.join(root_path, "logs")


DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
@@ -16,9 +17,9 @@ llm_model_config = {
}

LLM_MODEL = "vicuna-13b"


vicuna_model_server = "http://192.168.31.114:8000/"
LIMIT_MODEL_CONCURRENCY = 5
MAX_POSITION_EMBEDDINGS = 2048
vicuna_model_server = "http://192.168.31.114:8000"


# Load model config
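The model server below (started via `python vicuna_server.py` in the README hunk above) resolves its checkpoint path with `llm_model_config[LLM_MODEL]`. The body of that dict is elided in this hunk, so the sketch below only illustrates the shape such a mapping presumably has; the entry shown is an assumption, not the committed content.

```python
# Hypothetical example only: maps a model name to its on-disk checkpoint directory.
llm_model_config = {
    "vicuna-13b": os.path.join(model_path, "vicuna-13b"),
}
```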
pilot/conversation.py (new file, 141 lines)
@@ -0,0 +1,141 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-

import dataclasses
from enum import auto, Enum
from typing import List, Any


class SeparatorStyle(Enum):

    SINGLE = auto()
    TWO = auto()


@dataclasses.dataclass
class Conversation:
    """This class keeps all conversation history. """

    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    sep2: str = None

    # Used for gradio server
    skip_next: bool = False
    conv_id: Any = None

    def get_prompt(self):
        if self.sep_style == SeparatorStyle.SINGLE:
            ret = self.system + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ":"
            return ret

        elif self.sep_style == SeparatorStyle.TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ":" + message + seps[i % 2]
                else:
                    ret += role + ":"
            return ret

        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

    def append_message(self, role, message):
        self.messages.append([role, message])

    def to_gradio_chatbot(self):
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                ret.append([msg, None])
            else:
                ret[-1][-1] = msg

        return ret

    def copy(self):
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            conv_id=self.conv_id,
        )

    def dict(self):
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
            "conv_id": self.conv_id
        }


conv_one_shot = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant, who is very familiar with database related knowledge. "
           "The assistant gives helpful, detailed, professional and polite answers to the human's questions. ",
    roles=("Human", "Assistant"),
    messages=(
        (
            "Human",
            "What are the key differences between mysql and postgres?",
        ),
        (
            "Assistant",
            "MySQL and PostgreSQL are both popular open-source relational database management systems (RDBMS) "
            "that have many similarities but also some differences. Here are some key differences: \n"
            "1. Data Types: PostgreSQL has a more extensive set of data types, "
            "including support for array, hstore, JSON, and XML, whereas MySQL has a more limited set.\n"
            "2. ACID compliance: Both MySQL and PostgreSQL support ACID compliance (Atomicity, Consistency, Isolation, Durability), "
            "but PostgreSQL is generally considered to be more strict in enforcing it.\n"
            "3. Replication: MySQL has a built-in replication feature, which allows you to replicate data across multiple servers, "
            "whereas PostgreSQL has a similar feature, but it is not as mature as MySQL's.\n"
            "4. Performance: MySQL is generally considered to be faster and more efficient in handling large datasets, "
            "whereas PostgreSQL is known for its robustness and reliability.\n"
            "5. Licensing: MySQL is licensed under the GPL (General Public License), which means that it is free and open-source software, "
            "whereas PostgreSQL is licensed under the PostgreSQL License, which is also free and open-source but with different terms.\n"
            "Ultimately, the choice between MySQL and PostgreSQL depends on the specific needs and requirements of your application. "
            "Both are excellent database management systems, and choosing the right one "
            "for your project requires careful consideration of your application's requirements, performance needs, and scalability."
        ),
    ),
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###"
)

conv_vicuna_v1 = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant, who is very familiar with database related knowledge. "
           "The assistant gives helpful, detailed, professional and polite answers to the user's questions. ",
    roles=("USER", "ASSISTANT"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)

default_conversation = conv_one_shot

conv_templates = {
    "conv_one_shot": conv_one_shot,
    "vicuna_v1": conv_vicuna_v1
}
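To make the new conversation plumbing concrete, here is a small usage sketch of the `Conversation` API added above (the values shown in comments are illustrative):

```python
from pilot.conversation import conv_templates

# Work on a copy so the module-level default template is not mutated.
conv = conv_templates["conv_one_shot"].copy()
conv.append_message(conv.roles[0], "Which index type suits a JSONB column?")
conv.append_message(conv.roles[1], None)  # placeholder the model reply will fill

prompt = conv.get_prompt()        # system text + "###Human: ...###Assistant:" chain
pairs = conv.to_gradio_chatbot()  # [[user_msg, assistant_msg_or_None], ...] after the offset
```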
@@ -9,11 +9,8 @@ def generate_output(model, tokenizer, params, device, context_len=2048):
    temperature = float(params.get("temperature", 1.0))
    max_new_tokens = int(params.get("max_new_tokens", 256))
    stop_parameter = params.get("stop", None)

    print(tokenizer.__dir__())
    if stop_parameter == tokenizer.eos_token:
        stop_parameter = None

    stop_strings = []
    if isinstance(stop_parameter, str):
        stop_strings.append(stop_parameter)
@@ -43,33 +40,34 @@ def generate_output(model, tokenizer, params, device, context_len=2048):
            past_key_values=past_key_values,
        )
        logits = out.logits
        past_key_values = out.past_key_value
        past_key_values = out.past_key_values

        last_token_logits = logits[0][-1]

        if temperature < 1e-4:
            token = int(torch.argmax(last_token_logits))
        else:
            probs = torch.softmax(last_token_logits / temperature, dim=1)
            probs = torch.softmax(last_token_logits / temperature, dim=-1)
            token = int(torch.multinomial(probs, num_samples=1))

        output_ids.append(token)

        if token == tokenizer.eos_token_id:
            stopped = True
        else:
            stopped = False

        output = tokenizer.decode(output_ids, skip_special_tokens=True)
        for stop_str in stop_strings:
            pos = output.rfind(stop_str)
            if pos != -1:
                output = output[:pos]
                stoppped = True
                stopped = True
                break
            else:
                pass

        if stoppped:

        if stopped:
            break

    del past_key_values
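`generate_output` is driven by a plain params dict. A minimal call might look like the sketch below; the `"prompt"` key is an assumption based on how the other endpoints build their params, and the model/tokenizer loading is elided.

```python
# Illustrative only: the keys mirror those read at the top of generate_output.
params = {
    "prompt": "### Human: List the tables in the sales schema. ### Assistant:",  # assumed key
    "temperature": 0.7,      # params.get("temperature", 1.0)
    "max_new_tokens": 256,   # params.get("max_new_tokens", 256)
    "stop": "###",           # collected into stop_strings
}
# output = generate_output(model, tokenizer, params, device="cuda")
```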
@@ -81,7 +79,9 @@ def generate_output(model, tokenizer, params, device, context_len=2048):
@torch.inference_mode()
def get_embeddings(model, tokenizer, prompt):
    input_ids = tokenizer(prompt).input_ids
    input_embeddings = model.get_input_embeddings()
    embeddings = input_embeddings(torch.LongTensor([input_ids]))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    input_embeddings = model.get_input_embeddings().to(device)

    embeddings = input_embeddings(torch.LongTensor([input_ids]).to(device))
    mean = torch.mean(embeddings[0], 0).cpu().detach()
    return mean
    return mean.to(device)
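A call site for the device-aware `get_embeddings` would look roughly like this; the snippet is a sketch, not part of the commit, and assumes `model` and `tokenizer` come from whichever loader the servers below use.

```python
# Returns the mean of the prompt's input embeddings as a 1-D tensor.
vec = get_embeddings(model, tokenizer, "CREATE INDEX statements for large tables")
print(vec.shape)
```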
pilot/server/embdserver.py (new file, 3 lines)
@@ -0,0 +1,3 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
@@ -1,16 +1,34 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import uvicorn
import asyncio
import json
from typing import Optional, List
from fastapi import FastAPI
from fastapi import FastAPI, Request, BackgroundTasks
from fastapi.responses import StreamingResponse
from fastchat.serve.inference import generate_stream
from pydantic import BaseModel
from pilot.model.inference import generate_output, get_embeddings
from fastchat.serve.inference import load_model
from pilot.model.loader import ModerLoader
from pilot.configs.model_config import *

model_path = llm_model_config[LLM_MODEL]
ml = ModerLoader(model_path=model_path)
model, tokenizer = ml.loader(load_8bit=isload_8bit, debug=isdebug)


global_counter = 0
model_semaphore = None

# ml = ModerLoader(model_path=model_path)
# model, tokenizer = ml.loader(load_8bit=isload_8bit, debug=isdebug)
model, tokenizer = load_model(model_path=model_path, device=DEVICE, num_gpus=1, load_8bit=True, debug=False)

class ModelWorker:
    def __init__(self):
        pass

    # TODO

app = FastAPI()

@@ -21,9 +39,58 @@ class PromptRequest(BaseModel):
    stop: Optional[List[str]] = None


class StreamRequest(BaseModel):
    model: str
    prompt: str
    temperature: float
    max_new_tokens: int
    stop: str

class EmbeddingRequest(BaseModel):
    prompt: str

def release_model_semaphore():
    model_semaphore.release()


def generate_stream_gate(params):
    try:
        for output in generate_stream(
            model,
            tokenizer,
            params,
            DEVICE,
            MAX_POSITION_EMBEDDINGS,
        ):
            print("output: ", output)
            ret = {
                "text": output,
                "error_code": 0,
            }
            yield json.dumps(ret).encode() + b"\0"
    except torch.cuda.CudaError:
        ret = {
            "text": "**GPU OutOfMemory, Please Refresh.**",
            "error_code": 0
        }
        yield json.dumps(ret).encode() + b"\0"


@app.post("/generate_stream")
async def api_generate_stream(request: Request):
    global model_semaphore, global_counter
    global_counter += 1
    params = await request.json()
    print(model, tokenizer, params, DEVICE)

    if model_semaphore is None:
        model_semaphore = asyncio.Semaphore(LIMIT_MODEL_CONCURRENCY)
    await model_semaphore.acquire()

    generator = generate_stream_gate(params)
    background_tasks = BackgroundTasks()
    background_tasks.add_task(release_model_semaphore)
    return StreamingResponse(generator, background=background_tasks)

@app.post("/generate")
def generate(prompt_request: PromptRequest):
@@ -45,4 +112,9 @@ def embeddings(prompt_request: EmbeddingRequest):
    params = {"prompt": prompt_request.prompt}
    print("Received prompt: ", params["prompt"])
    output = get_embeddings(model, tokenizer, params["prompt"])
    return {"response": [float(x) for x in output]}
    return {"response": [float(x) for x in output]}


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", log_level="info")
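With `vicuna_server.py` running, the streaming endpoint can be exercised directly. The sketch below mirrors how `webserver.py` consumes it (NUL-delimited JSON chunks); the localhost URL is an assumption based on uvicorn's default port, so substitute your own host as needed.

```python
import json
import requests

payload = {
    "model": "vicuna-13b",
    "prompt": "### Human: What is a covering index? ### Assistant:",
    "temperature": 0.7,
    "max_new_tokens": 256,
    "stop": "###",
}
resp = requests.post("http://127.0.0.1:8000/generate_stream",
                     json=payload, stream=True, timeout=20)
for chunk in resp.iter_lines(decode_unicode=False, delimiter=b"\0"):
    if chunk:
        data = json.loads(chunk.decode())
        if data["error_code"] == 0:
            print(data["text"])  # cumulative decoded text, prompt echo included
```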
pilot/server/webserver.py (new file, 352 lines)
@@ -0,0 +1,352 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse
import os
import uuid
import json
import time
import gradio as gr
import datetime
import requests
from urllib.parse import urljoin

from pilot.configs.model_config import LOGDIR, vicuna_model_server, LLM_MODEL

from pilot.conversation import (
    default_conversation,
    conv_templates,
    SeparatorStyle
)

from fastchat.utils import (
    build_logger,
    server_error_msg,
    violates_moderation,
    moderation_msg
)

from fastchat.serve.gradio_patch import Chatbot as grChatbot
from fastchat.serve.gradio_css import code_highlight_css

logger = build_logger("webserver", "webserver.log")
headers = {"User-Agent": "dbgpt Client"}

no_change_btn = gr.Button.update()
enable_btn = gr.Button.update(interactive=True)
disable_btn = gr.Button.update(interactive=True)

enable_moderation = False
models = []

priority = {
    "vicuna-13b": "aaa"
}

get_window_url_params = """
function() {
    const params = new URLSearchParams(window.location.search);
    url_params = Object.fromEntries(params);
    console.log(url_params);
    gradioURL = window.location.href
    if (!gradioURL.endsWith('?__theme=dark')) {
        window.location.replace(gradioURL + '?__theme=dark');
    }
    return url_params;
}
"""
def load_demo(url_params, request: gr.Request):
    logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")

    dropdown_update = gr.Dropdown.update(visible=True)
    if "model" in url_params:
        model = url_params["model"]
        if model in models:
            dropdown_update = gr.Dropdown.update(
                value=model, visible=True)

    state = default_conversation.copy()
    return (state,
            dropdown_update,
            gr.Chatbot.update(visible=True),
            gr.Textbox.update(visible=True),
            gr.Button.update(visible=True),
            gr.Row.update(visible=True),
            gr.Accordion.update(visible=True))

def get_conv_log_filename():
    t = datetime.datetime.now()
    name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
    return name


def regenerate(state, request: gr.Request):
    logger.info(f"regenerate. ip: {request.client.host}")
    state.messages[-1][-1] = None
    state.skip_next = False
    return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5

def clear_history(request: gr.Request):
    logger.info(f"clear_history. ip: {request.client.host}")
    state = None
    return (state, [], "") + (disable_btn,) * 5


def add_text(state, text, request: gr.Request):
    logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
    if len(text) <= 0:
        state.skip_next = True
        return (state, state.to_gradio_chatbot(), "") + (no_change_btn,) * 5
    if args.moderate:
        flagged = violates_moderation(text)
        if flagged:
            state.skip_next = True
            return (state, state.to_gradio_chatbot(), moderation_msg) + (
                no_change_btn,) * 5

    text = text[:1536]  # Hard cut-off
    state.append_message(state.roles[0], text)
    state.append_message(state.roles[1], None)
    state.skip_next = False
    return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5

def post_process_code(code):
    sep = "\n```"
    if sep in code:
        blocks = code.split(sep)
        if len(blocks) % 2 == 1:
            for i in range(1, len(blocks), 2):
                blocks[i] = blocks[i].replace("\\_", "_")
        code = sep.join(blocks)
    return code
def http_bot(state, temperature, max_new_tokens, request: gr.Request):
    start_tstamp = time.time()
    model_name = LLM_MODEL

    if state.skip_next:
        # This generate call is skipped due to invalid inputs
        yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
        return

    if len(state.messages) == state.offset + 2:
        # First round of conversation
        template_name = "conv_one_shot"
        new_state = conv_templates[template_name].copy()
        new_state.conv_id = uuid.uuid4().hex
        new_state.append_message(new_state.roles[0], state.messages[-2][1])
        new_state.append_message(new_state.roles[1], None)
        state = new_state

    prompt = state.get_prompt()
    skip_echo_len = len(prompt.replace("</s>", " ")) + 1

    # Make requests
    payload = {
        "model": model_name,
        "prompt": prompt,
        "temperature": float(temperature),
        "max_new_tokens": int(max_new_tokens),
        "stop": state.sep if state.sep_style == SeparatorStyle.SINGLE else state.sep2,
    }
    logger.info(f"Request: \n{payload}")

    state.messages[-1][-1] = "▌"
    yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5

    try:
        # Stream output
        response = requests.post(urljoin(vicuna_model_server, "generate_stream"),
                                 headers=headers, json=payload, stream=True, timeout=20)
        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
            if chunk:
                data = json.loads(chunk.decode())
                if data["error_code"] == 0:
                    output = data["text"][skip_echo_len:].strip()
                    output = post_process_code(output)
                    state.messages[-1][-1] = output + "▌"
                    yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
                else:
                    output = data["text"] + f" (error_code: {data['error_code']})"
                    state.messages[-1][-1] = output
                    yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
                    return
                time.sleep(0.02)
    except requests.exceptions.RequestException as e:
        state.messages[-1][-1] = server_error_msg + f" (error_code: 4)"
        yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
        return

    state.messages[-1][-1] = state.messages[-1][-1][:-1]
    yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5

    finish_tstamp = time.time()
    logger.info(f"{output}")

    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name,
            "start": round(start_tstamp, 4),
            "finish": round(start_tstamp, 4),
            "state": state.dict(),
            "ip": request.client.host,
        }
        fout.write(json.dumps(data) + "\n")
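The `stop` string sent to the model server depends on which template is active; with the templates defined in `pilot/conversation.py` it resolves as in this small sketch:

```python
from pilot.conversation import conv_templates, SeparatorStyle

for name, conv in conv_templates.items():
    stop = conv.sep if conv.sep_style == SeparatorStyle.SINGLE else conv.sep2
    print(name, repr(stop))
# conv_one_shot -> '###'   (SeparatorStyle.SINGLE)
# vicuna_v1     -> '</s>'  (SeparatorStyle.TWO)
```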
block_css = (
    code_highlight_css
    + """
pre {
    white-space: pre-wrap;       /* Since CSS 2.1 */
    white-space: -moz-pre-wrap;  /* Mozilla, since 1999 */
    white-space: -pre-wrap;      /* Opera 4-6 */
    white-space: -o-pre-wrap;    /* Opera 7 */
    word-wrap: break-word;       /* Internet Explorer 5.5+ */
}
#notice_markdown th {
    display: none;
}
"""
)


def build_single_model_ui():

    notice_markdown = """
# DB-GPT

[DB-GPT](https://github.com/csunny/DB-GPT) 是一个实验性的开源应用程序,它基于[FastChat](https://github.com/lm-sys/FastChat),并使用vicuna-13b作为基础模型。此外,此程序结合了langchain和llama-index基于现有知识库进行In-Context Learning来对其进行数据库相关知识的增强。它可以进行SQL生成、SQL诊断、数据库知识问答等一系列的工作。 总的来说,它是一个用于数据库的复杂且创新的AI工具。如果您对如何在工作中使用或实施DB-GPT有任何具体问题,请联系我, 我会尽力提供帮助, 同时也欢迎大家参与到项目建设中, 做一些有趣的事情。
"""
    learn_more_markdown = """
### Licence
The service is a research preview intended for non-commercial use only. subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA
"""

    state = gr.State()
    notice = gr.Markdown(notice_markdown, elem_id="notice_markdown")

    with gr.Accordion("参数", open=False, visible=False) as parameter_row:
        temperature = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.7,
            step=0.1,
            interactive=True,
            label="Temperature",
        )

        max_output_tokens = gr.Slider(
            minimum=0,
            maximum=1024,
            value=512,
            step=64,
            interactive=True,
            label="最大输出Token数",
        )

    chatbot = grChatbot(elem_id="chatbot", visible=False).style(height=550)
    with gr.Row():
        with gr.Column(scale=20):
            textbox = gr.Textbox(
                show_label=False,
                placeholder="Enter text and press ENTER",
                visible=False,
            ).style(container=False)

        with gr.Column(scale=2, min_width=50):
            send_btn = gr.Button(value="" "发送", visible=False)

    with gr.Row(visible=False) as button_row:
        regenerate_btn = gr.Button(value="🔄" "重新生成", interactive=False)
        clear_btn = gr.Button(value="🗑️" "清理", interactive=False)

    gr.Markdown(learn_more_markdown)

    btn_list = [regenerate_btn, clear_btn]
    regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
        http_bot,
        [state, temperature, max_output_tokens],
        [state, chatbot] + btn_list,
    )
    clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)

    textbox.submit(
        add_text, [state, textbox], [state, chatbot, textbox] + btn_list
    ).then(
        http_bot,
        [state, temperature, max_output_tokens],
        [state, chatbot] + btn_list,
    )

    send_btn.click(
        add_text, [state, textbox], [state, chatbot, textbox] + btn_list
    ).then(
        http_bot,
        [state, temperature, max_output_tokens],
        [state, chatbot] + btn_list
    )

    return state, chatbot, textbox, send_btn, button_row, parameter_row
def build_webdemo():
    with gr.Blocks(
        title="数据库智能助手",
        # theme=gr.themes.Base(),
        theme=gr.themes.Default(),
        css=block_css,
    ) as demo:
        url_params = gr.JSON(visible=False)
        (
            state,
            chatbot,
            textbox,
            send_btn,
            button_row,
            parameter_row,
        ) = build_single_model_ui()

        if args.model_list_mode == "once":
            demo.load(
                load_demo,
                [url_params],
                [
                    state,
                    chatbot,
                    textbox,
                    send_btn,
                    button_row,
                    parameter_row,
                ],
                _js=get_window_url_params,
            )
        else:
            raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
    return demo

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int)
    parser.add_argument("--concurrency-count", type=int, default=10)
    parser.add_argument(
        "--model-list-mode", type=str, default="once", choices=["once", "reload"]
    )
    parser.add_argument("--share", default=False, action="store_true")
    parser.add_argument(
        "--moderate", action="store_true", help="Enable content moderation"
    )
    args = parser.parse_args()
    logger.info(f"args: {args}")

    logger.info(args)
    demo = build_webdemo()
    demo.queue(
        concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False
    ).launch(
        server_name=args.host, server_port=args.port, share=args.share, max_threads=200,
    )
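Each completed exchange is appended by `http_bot` to a JSON-lines file under `LOGDIR`. A minimal reader sketch follows; the date in the file name is only an example, while the field names match the `data` dict written above.

```python
import json
import os

log_file = os.path.join("logs", "2023-05-06-conv.json")  # example date
with open(log_file) as f:
    for line in f:
        rec = json.loads(line)
        print(rec["tstamp"], rec["model"], rec["state"]["messages"][-2:])
```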
@@ -49,4 +49,5 @@ umap-learn
notebook
gradio==3.24.1
gradio-client==0.0.8
wandb
wandb
fschat=0.1.10