Mirror of https://github.com/csunny/DB-GPT.git

commit 4c7ba5021a
parent 3f13fe3a7f

fix problem

@@ -60,3 +60,4 @@ dependencies:
 - gradio==3.24.1
 - gradio-client==0.0.8
 - wandb
+- fschat=0.1.10
@@ -3,7 +3,7 @@
 import dataclasses
 from enum import auto, Enum
-from typing import List, Tuple, Any
+from typing import List, Any


 class SeparatorStyle(Enum):
@@ -29,12 +29,12 @@ class Conversation:

     def get_prompt(self):
         if self.sep_style == SeparatorStyle.SINGLE:
-            ret = self.system
+            ret = self.system + self.sep
             for role, message in self.messages:
                 if message:
-                    ret += self.sep + " " + role + ": " + message
+                    ret += role + ": " + message + self.sep
                 else:
-                    ret += self.sep + " " + role + ":"
+                    ret += role + ":"
             return ret

         elif self.sep_style == SeparatorStyle.TWO:
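
For reference, a minimal standalone sketch of the prompt layout the reworked SINGLE branch now produces; the system text, roles and separator below are made-up placeholders, not values from the repository:

    # New SINGLE-style layout: separator after the system text and after each
    # completed message, and a bare "role:" left open for the pending reply.
    system = "A chat between a curious user and an assistant."
    sep = "###"
    messages = [["Human", "Hello"], ["Assistant", None]]

    ret = system + sep
    for role, message in messages:
        if message:
            ret += role + ": " + message + sep
        else:
            ret += role + ":"

    print(ret)
    # A chat between a curious user and an assistant.###Human: Hello###Assistant:
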
@@ -56,7 +56,7 @@ class Conversation:

     def to_gradio_chatbot(self):
         ret = []
-        for i, (role, msg) in enumerate(self.messages[self.offset :]):
+        for i, (role, msg) in enumerate(self.messages[self.offset:]):
             if i % 2 == 0:
                 ret.append([msg, None])
             else:
@@ -133,19 +133,9 @@ conv_vicuna_v1 = Conversation(
     sep2="</s>",
 )

-conv_template = {
+default_conversation = conv_one_shot
+
+conv_templates = {
     "conv_one_shot": conv_one_shot,
     "vicuna_v1": conv_vicuna_v1
 }
-
-
-def get_default_conv_template(model_name: str = "vicuna-13b"):
-    model_name = model_name.lower()
-    if "vicuna" in model_name:
-        return conv_vicuna_v1
-    return conv_one_shot
-
-
-def compute_skip_echo_len(prompt):
-    skip_echo_len = len(prompt) + 1 - prompt.count("</s>") * 3
-    return skip_echo_len
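
With get_default_conv_template() and compute_skip_echo_len() gone, callers pick a template straight from the dict, which is exactly what the webserver hunks further down switch to. A minimal usage sketch, assuming the pilot.conversation import path shown in those hunks:

    from pilot.conversation import conv_templates, default_conversation

    # Explicit template choice, as http_bot() does below:
    state = conv_templates["conv_one_shot"].copy()

    # Or start from the default alias, as load_demo() does below:
    state = default_conversation.copy()
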
@@ -82,6 +82,7 @@ async def api_generate_stream(request: Request):
     global_counter += 1
     params = await request.json()
+    print(model, tokenizer, params, DEVICE)

     if model_semaphore is None:
         model_semaphore = asyncio.Semaphore(LIMIT_MODEL_CONCURRENCY)
     await model_semaphore.acquire()
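
The context around this hunk also shows the lazily created semaphore that caps concurrent generate calls. A self-contained sketch of that pattern, assuming a placeholder concurrency limit and a dummy generation step (in the real handler the semaphore is released later, after streaming finishes, rather than in a try/finally):

    import asyncio

    LIMIT_MODEL_CONCURRENCY = 5   # placeholder value, configured elsewhere
    model_semaphore = None

    async def generate_stream_stub(params):
        global model_semaphore
        # Create the semaphore on first use, then gate every request on it.
        if model_semaphore is None:
            model_semaphore = asyncio.Semaphore(LIMIT_MODEL_CONCURRENCY)
        await model_semaphore.acquire()
        try:
            await asyncio.sleep(0.1)   # stand-in for the actual model call
            return {"text": "ok", "error_code": 0}
        finally:
            model_semaphore.release()

    asyncio.run(generate_stream_stub({"prompt": "hi"}))
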
@@ -14,8 +14,8 @@ from urllib.parse import urljoin
 from pilot.configs.model_config import LOGDIR, vicuna_model_server, LLM_MODEL

 from pilot.conversation import (
-    get_default_conv_template,
-    compute_skip_echo_len,
+    default_conversation,
+    conv_templates,
     SeparatorStyle
 )

@@ -43,29 +43,6 @@ priority = {
     "vicuna-13b": "aaa"
 }

-def set_global_vars(enable_moderation_):
-    global enable_moderation, models
-    enable_moderation = enable_moderation_
-
-def load_demo_single(url_params):
-    dropdown_update = gr.Dropdown.update(visible=True)
-    if "model" in url_params:
-        model = url_params["model"]
-        if model in models:
-            dropdown_update = gr.Dropdown.update(value=model, visible=True)
-
-    state = None
-    return (
-        state,
-        dropdown_update,
-        gr.Chatbot.update(visible=True),
-        gr.Textbox.update(visible=True),
-        gr.Button.update(visible=True),
-        gr.Row.update(visible=True),
-        gr.Accordion.update(visible=True),
-    )
-
-
 get_window_url_params = """
 function() {
     const params = new URLSearchParams(window.location.search);
@@ -78,10 +55,24 @@ function() {
     return url_params;
 }
 """

 def load_demo(url_params, request: gr.Request):
     logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
-    return load_demo_single(url_params)
+    dropdown_update = gr.Dropdown.update(visible=True)
+    if "model" in url_params:
+        model = url_params["model"]
+        if model in models:
+            dropdown_update = gr.Dropdown.update(
+                value=model, visible=True)
+
+    state = default_conversation.copy()
+    return (state,
+            dropdown_update,
+            gr.Chatbot.update(visible=True),
+            gr.Textbox.update(visible=True),
+            gr.Button.update(visible=True),
+            gr.Row.update(visible=True),
+            gr.Accordion.update(visible=True))

 def get_conv_log_filename():
     t = datetime.datetime.now()
@@ -100,29 +91,23 @@ def clear_history(request: gr.Request):
     state = None
     return (state, [], "") + (disable_btn,) * 5

-def add_text(state, text, request: gr.Request):
-    logger.info(f"add_text. ip: {request.client.host}. len:{len(text)}")
-
-    if state is None:
-        state = get_default_conv_template("vicuna").copy()
-
+def add_text(state, text, request: gr.Request):
+    logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
     if len(text) <= 0:
         state.skip_next = True
         return (state, state.to_gradio_chatbot(), "") + (no_change_btn,) * 5
-
-    # TODO: evaluate the model with ChatGPT
-    if enable_moderation:
+    if args.moderate:
         flagged = violates_moderation(text)
         if flagged:
             logger.info(f"violate moderation. ip: {request.client.host}. text: {text}")
             state.skip_next = True
-            return (state, state.to_gradio_chatbot(), moderation_msg) + (no_change_btn,) * 5
-
-    text = text[:1536]  # ?
-    state.append_message(state.roles[0], text)
+            return (state, state.to_gradio_chatbot(), moderation_msg) + (
+                no_change_btn,) * 5
+    text = text[:1536]  # Hard cut-off
+    state.append_message(state.roles[0], text)
     state.append_message(state.roles[1], None)
     state.skip_next = False

     return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5

 def post_process_code(code):
@@ -136,104 +121,80 @@ def post_process_code(code):
     return code

 def http_bot(state, temperature, max_new_tokens, request: gr.Request):
     logger.info(f"http_bot. ip: {request.client.host}")
     start_tstamp = time.time()

-    model_name = LLM_MODEL
-    temperature = float(temperature)
-    max_new_tokens = int(max_new_tokens)
+    model_name = LLM_MODEL

     if state.skip_next:
         # This generate call is skipped due to invalid inputs
         yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
         return

     if len(state.messages) == state.offset + 2:
         # First round of conversation
-        new_state = get_default_conv_template(model_name).copy()
+        template_name = "conv_one_shot"
+        new_state = conv_templates[template_name].copy()
         new_state.conv_id = uuid.uuid4().hex
         new_state.append_message(new_state.roles[0], state.messages[-2][1])
         new_state.append_message(new_state.roles[1], None)
         state = new_state

     prompt = state.get_prompt()
-    skip_echo_len = compute_skip_echo_len(prompt)
+    skip_echo_len = len(prompt.replace("</s>", " ")) + 1

     logger.info(f"State: {state}")
     # Make requests
     payload = {
         "model": model_name,
         "prompt": prompt,
-        "temperature": temperature,
-        "max_new_tokens": max_new_tokens,
-        "stop": state.sep if state.sep_style == SeparatorStyle.SINGLE else None,
+        "temperature": float(temperature),
+        "max_new_tokens": int(max_new_tokens),
+        "stop": state.sep if state.sep_style == SeparatorStyle.SINGLE else state.sep2,
     }
-    logger.info(f"Requert: \n{payload}")
+    logger.info(f"Request: \n {payload}")

     state.messages[-1][-1] = "▌"
     yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5

     try:
-        response = requests.post(
-            url=urljoin(vicuna_model_server, "worker_generate_stream"),
-            headers=headers,
-            json=payload,
-            stream=True,
-            timeout=20,
-        )
+        # Stream output
+        response = requests.post(urljoin(vicuna_model_server, "worker_generate_stream"),
+            headers=headers, json=payload, stream=True, timeout=20)
         for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
             if chunk:
                 data = json.loads(chunk.decode())
                 logger.info(f"Response: {data}")
                 if data["error_code"] == 0:
-                    output = data["text"][skip_echo_len].strip()
+                    output = data["text"][skip_echo_len:].strip()
                     output = post_process_code(output)
                     state.messages[-1][-1] = output + "▌"
                     yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
                 else:
-                    output = data["text"] + f" (error_code): {data['error_code']}"
+                    output = data["text"] + f" (error_code: {data['error_code']})"
                     state.messages[-1][-1] = output
-                    yield (state, state.to_gradio_chatbot()) + (
-                        disable_btn,
-                        disable_btn,
-                        disable_btn,
-                        enable_btn,
-                        enable_btn
-                    )
+                    yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
                     return
             time.sleep(0.02)
     except requests.exceptions.RequestException as e:
         state.messages[-1][-1] = server_error_msg + f" (error_code: 4)"
-        yield (state, state.to_gradio_chatbot()) + (
-            disable_btn,
-            disable_btn,
-            disable_btn,
-            enable_btn,
-            enable_btn
-        )
+        yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
         return

     state.messages[-1][-1] = state.messages[-1][-1][:-1]
     yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5

     finish_tstamp = time.time()
     logger.info(f"{output}")

-    with open(get_conv_log_filename(), "a") as flog:
+    with open(get_conv_log_filename(), "a") as fout:
         data = {
             "tstamp": round(finish_tstamp, 4),
             "type": "chat",
             "model": model_name,
             "gen_params": {
                 "temperature": temperature,
                 "max_new_tokens": max_new_tokens,
             },
             "start": round(start_tstamp, 4),
-            "finish": round(finish_tstamp, 4),
+            "finish": round(start_tstamp, 4),
             "state": state.dict(),
             "ip": request.client.host,
         }
-        flog.write(json.dumps(data), + "\n")
+        fout.write(json.dumps(data) + "\n")

 block_css = (
     code_highlight_css
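
Pulled out of the rewritten http_bot() above, a stripped-down sketch of the streaming client loop: the worker streams null-byte-delimited JSON chunks, each carrying the full text so far, and the prompt echo is sliced off with skip_echo_len. The endpoint and field names are the ones used in the diff; the model name, prompt and sampling values below are placeholders. Note that len(prompt.replace("</s>", " ")) + 1 is equivalent to the removed compute_skip_echo_len(), since each 4-character "</s>" collapses to a single space (len(prompt) + 1 - 3 * count).

    import json
    from urllib.parse import urljoin

    import requests

    from pilot.configs.model_config import vicuna_model_server

    prompt = "A chat between a curious user and an assistant.###Human: Hello###Assistant:"
    payload = {
        "model": "vicuna-13b",        # placeholder; the webserver passes LLM_MODEL
        "prompt": prompt,
        "temperature": 0.7,           # placeholder sampling values
        "max_new_tokens": 512,
        "stop": "###",
    }
    skip_echo_len = len(prompt.replace("</s>", " ")) + 1

    response = requests.post(
        urljoin(vicuna_model_server, "worker_generate_stream"),
        json=payload,
        stream=True,
        timeout=20,
    )
    # Each chunk is a JSON object terminated by a null byte; the "text" field
    # always contains prompt + generation so far, so the echo is cut every time.
    for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
        if not chunk:
            continue
        data = json.loads(chunk.decode())
        if data["error_code"] == 0:
            print(data["text"][skip_echo_len:].strip())
        else:
            print(f"worker error: {data['error_code']}")
            break
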
@@ -382,8 +343,6 @@ if __name__ == "__main__":
     args = parser.parse_args()
     logger.info(f"args: {args}")

-    set_global_vars(args.moderate)
-
     logger.info(args)
     demo = build_webdemo()
     demo.queue(
@@ -49,4 +49,5 @@ umap-learn
 notebook
 gradio==3.24.1
 gradio-client==0.0.8
 wandb
+fschat=0.1.10