fix problem

csunny 2023-04-30 23:58:32 +08:00
parent 3f13fe3a7f
commit 4c7ba5021a
5 changed files with 60 additions and 108 deletions

View File

@@ -60,3 +60,4 @@ dependencies:
   - gradio==3.24.1
   - gradio-client==0.0.8
   - wandb
+  - fschat==0.1.10

View File

@@ -3,7 +3,7 @@
 import dataclasses
 from enum import auto, Enum
-from typing import List, Tuple, Any
+from typing import List, Any
 class SeparatorStyle(Enum):
@@ -29,12 +29,12 @@ class Conversation:
     def get_prompt(self):
         if self.sep_style == SeparatorStyle.SINGLE:
-            ret = self.system
+            ret = self.system + self.sep
             for role, message in self.messages:
                 if message:
-                    ret += self.sep + " " + role + ": " + message
+                    ret += role + ": " + message + self.sep
                 else:
-                    ret += self.sep + " " + role + ":"
+                    ret += role + ":"
             return ret
         elif self.sep_style == SeparatorStyle.TWO:
@@ -56,7 +56,7 @@ class Conversation:
     def to_gradio_chatbot(self):
         ret = []
-        for i, (role, msg) in enumerate(self.messages[self.offset :]):
+        for i, (role, msg) in enumerate(self.messages[self.offset:]):
             if i % 2 == 0:
                 ret.append([msg, None])
             else:
@@ -133,19 +133,9 @@ conv_vicuna_v1 = Conversation(
     sep2="</s>",
 )
-conv_template = {
+default_conversation = conv_one_shot
+conv_templates = {
     "conv_one_shot": conv_one_shot,
     "vicuna_v1": conv_vicuna_v1
 }
-def get_default_conv_template(model_name: str = "vicuna-13b"):
-    model_name = model_name.lower()
-    if "vicuna" in model_name:
-        return conv_vicuna_v1
-    return conv_one_shot
-def compute_skip_echo_len(prompt):
-    skip_echo_len = len(prompt) + 1 - prompt.count("</s>") * 3
-    return skip_echo_len

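Note on the conversation change above: get_prompt now places the separator after the system prompt and after each completed message, rather than before each role. A minimal standalone sketch of the new SeparatorStyle.SINGLE layout (values are illustrative, not from the repo):

system = "A chat between a curious user and an AI assistant."
sep = "###"
messages = [("Human", "Hello"), ("Assistant", None)]

ret = system + sep                          # was: ret = self.system
for role, message in messages:
    if message:
        ret += role + ": " + message + sep  # was: self.sep + " " + role + ": " + message
    else:
        ret += role + ":"                   # open slot the model completes
print(ret)
# A chat between a curious user and an AI assistant.###Human: Hello###Assistant:
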
View File

@@ -82,6 +82,7 @@ async def api_generate_stream(request: Request):
     global_counter += 1
     params = await request.json()
+    print(model, tokenizer, params, DEVICE)
     if model_semaphore is None:
         model_semaphore = asyncio.Semaphore(LIMIT_MODEL_CONCURRENCY)
     await model_semaphore.acquire()

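The hunk above also shows the lazy concurrency guard around generation: the first request creates a global asyncio.Semaphore, and every request must acquire a slot before touching the model. A sketch of the pattern, assuming the semaphore is released once streaming finishes (the release is outside this hunk):

import asyncio

LIMIT_MODEL_CONCURRENCY = 5  # illustrative value
model_semaphore = None       # created lazily by the first request

async def generate_stream(params):
    global model_semaphore
    if model_semaphore is None:
        model_semaphore = asyncio.Semaphore(LIMIT_MODEL_CONCURRENCY)
    await model_semaphore.acquire()  # wait if too many generations are in flight
    try:
        ...  # run the model and stream chunks back
    finally:
        model_semaphore.release()    # assumed to happen when the stream ends
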
View File

@@ -14,8 +14,8 @@ from urllib.parse import urljoin
 from pilot.configs.model_config import LOGDIR, vicuna_model_server, LLM_MODEL
 from pilot.conversation import (
-    get_default_conv_template,
-    compute_skip_echo_len,
+    default_conversation,
+    conv_templates,
     SeparatorStyle
 )
@@ -43,29 +43,6 @@ priority = {
     "vicuna-13b": "aaa"
 }
-def set_global_vars(enable_moderation_):
-    global enable_moderation, models
-    enable_moderation = enable_moderation_
-def load_demo_single(url_params):
-    dropdown_update = gr.Dropdown.update(visible=True)
-    if "model" in url_params:
-        model = url_params["model"]
-        if model in models:
-            dropdown_update = gr.Dropdown.update(value=model, visible=True)
-    state = None
-    return (
-        state,
-        dropdown_update,
-        gr.Chatbot.update(visible=True),
-        gr.Textbox.update(visible=True),
-        gr.Button.update(visible=True),
-        gr.Row.update(visible=True),
-        gr.Accordion.update(visible=True),
-    )
 get_window_url_params = """
 function() {
     const params = new URLSearchParams(window.location.search);
@@ -78,10 +55,24 @@ function() {
     return url_params;
 }
 """
 def load_demo(url_params, request: gr.Request):
     logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
-    return load_demo_single(url_params)
+    dropdown_update = gr.Dropdown.update(visible=True)
+    if "model" in url_params:
+        model = url_params["model"]
+        if model in models:
+            dropdown_update = gr.Dropdown.update(
+                value=model, visible=True)
+    state = default_conversation.copy()
+    return (state,
+            dropdown_update,
+            gr.Chatbot.update(visible=True),
+            gr.Textbox.update(visible=True),
+            gr.Button.update(visible=True),
+            gr.Row.update(visible=True),
+            gr.Accordion.update(visible=True))
 def get_conv_log_filename():
     t = datetime.datetime.now()
@@ -100,29 +91,23 @@ def clear_history(request: gr.Request):
     state = None
     return (state, [], "") + (disable_btn,) * 5
-def add_text(state, text, request: gr.Request):
-    logger.info(f"add_text. ip: {request.client.host}. len:{len(text)}")
-    if state is None:
-        state = get_default_conv_template("vicuna").copy()
+def add_text(state, text, request: gr.Request):
+    logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
     if len(text) <= 0:
         state.skip_next = True
         return (state, state.to_gradio_chatbot(), "") + (no_change_btn,) * 5
     # TODO: use ChatGPT for model evaluation
-    if enable_moderation:
+    if args.moderate:
         flagged = violates_moderation(text)
         if flagged:
             logger.info(f"violate moderation. ip: {request.client.host}. text: {text}")
             state.skip_next = True
-            return (state, state.to_gradio_chatbot(), moderation_msg) + (no_change_btn,) * 5
+            return (state, state.to_gradio_chatbot(), moderation_msg) + (
+                no_change_btn,) * 5
-    text = text[:1536]  # ?
-    state.append_message(state.roles[0], text)
+    text = text[:1536]  # Hard cut-off
+    state.append_message(state.roles[0], text)
     state.append_message(state.roles[1], None)
     state.skip_next = False
     return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
 def post_process_code(code):
@@ -136,104 +121,80 @@ def post_process_code(code):
     return code
 def http_bot(state, temperature, max_new_tokens, request: gr.Request):
     logger.info(f"http_bot. ip: {request.client.host}")
     start_tstamp = time.time()
-    model_name = LLM_MODEL
-    temperature = float(temperature)
-    max_new_tokens = int(max_new_tokens)
+    model_name = LLM_MODEL
     if state.skip_next:
         # This generate call is skipped due to invalid inputs
         yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
         return
     if len(state.messages) == state.offset + 2:
-        # 第一轮对话
-        new_state = get_default_conv_template(model_name).copy()
+        # First round of conversation
+        template_name = "conv_one_shot"
+        new_state = conv_templates[template_name].copy()
         new_state.conv_id = uuid.uuid4().hex
         new_state.append_message(new_state.roles[0], state.messages[-2][1])
         new_state.append_message(new_state.roles[1], None)
         state = new_state
     prompt = state.get_prompt()
-    skip_echo_len = compute_skip_echo_len(prompt)
+    skip_echo_len = len(prompt.replace("</s>", " ")) + 1
     logger.info(f"State: {state}")
     # Make requests
     payload = {
         "model": model_name,
         "prompt": prompt,
-        "temperature": temperature,
-        "max_new_tokens": max_new_tokens,
-        "stop": state.sep if state.sep_style == SeparatorStyle.SINGLE else None,
+        "temperature": float(temperature),
+        "max_new_tokens": int(max_new_tokens),
+        "stop": state.sep if state.sep_style == SeparatorStyle.SINGLE else state.sep2,
     }
-    logger.info(f"Requert: \n{payload}")
+    logger.info(f"Request: \n {payload}")
     state.messages[-1][-1] = ""
     yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
     try:
-        response = requests.post(
-            url=urljoin(vicuna_model_server, "worker_generate_stream"),
-            headers=headers,
-            json=payload,
-            stream=True,
-            timeout=20,
-        )
+        # Stream output
+        response = requests.post(urljoin(vicuna_model_server, "worker_generate_stream"),
+                                 headers=headers, json=payload, stream=True, timeout=20)
         for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
             if chunk:
                 data = json.loads(chunk.decode())
                 logger.info(f"Response: {data}")
                 if data["error_code"] == 0:
-                    output = data["text"][skip_echo_len].strip()
+                    output = data["text"][skip_echo_len:].strip()
                     output = post_process_code(output)
                     state.messages[-1][-1] = output + "▌"
                     yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
                 else:
-                    output = data["text"] + f" (error_code): {data['error_code']}"
+                    output = data["text"] + f" (error_code: {data['error_code']})"
                     state.messages[-1][-1] = output
-                    yield (state, state.to_gradio_chatbot()) + (
-                        disable_btn,
-                        disable_btn,
-                        disable_btn,
-                        enable_btn,
-                        enable_btn
-                    )
+                    yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
                     return
                 time.sleep(0.02)
     except requests.exceptions.RequestException as e:
         state.messages[-1][-1] = server_error_msg + f" (error_code: 4)"
-        yield (state, state.to_gradio_chatbot()) + (
-            disable_btn,
-            disable_btn,
-            disable_btn,
-            enable_btn,
-            enable_btn
-        )
+        yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
         return
     state.messages[-1][-1] = state.messages[-1][-1][:-1]
     yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
     finish_tstamp = time.time()
     logger.info(f"{output}")
-    with open(get_conv_log_filename(), "a") as flog:
+    with open(get_conv_log_filename(), "a") as fout:
         data = {
             "tstamp": round(finish_tstamp, 4),
             "type": "chat",
             "model": model_name,
             "gen_params": {
                 "temperature": temperature,
                 "max_new_tokens": max_new_tokens,
             },
             "start": round(start_tstamp, 4),
-            "finish": round(start_tstamp, 4),
+            "finish": round(finish_tstamp, 4),
             "state": state.dict(),
             "ip": request.client.host,
         }
-        flog.write(json.dumps(data), + "\n")
+        fout.write(json.dumps(data) + "\n")
 block_css = (
     code_highlight_css
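
The inlined skip_echo_len above is equivalent to the removed compute_skip_echo_len helper: replacing each four-character "</s>" with a single space shortens the prompt by three characters per occurrence, so both forms reduce to len(prompt) + 1 - 3 * prompt.count("</s>"). A quick standalone check:

prompt = "USER: hi</s>ASSISTANT:"
old = len(prompt) + 1 - prompt.count("</s>") * 3  # removed helper's formula
new = len(prompt.replace("</s>", " ")) + 1        # inlined replacement
assert old == new == 20
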
@ -382,8 +343,6 @@ if __name__ == "__main__":
args = parser.parse_args()
logger.info(f"args: {args}")
set_global_vars(args.moderate)
logger.info(args)
demo = build_webdemo()
demo.queue(

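For reference, the streaming exchange in http_bot has this shape: the worker emits NUL-delimited JSON chunks and the client slices off the echoed prompt. A minimal client sketch under the same assumptions (endpoint path, payload keys, and chunk format taken from the diff above; error handling simplified):

import json
from urllib.parse import urljoin

import requests

def stream_chat(server_url, payload, skip_echo_len):
    # Chunks are NUL-separated JSON objects, hence delimiter=b"\0".
    response = requests.post(urljoin(server_url, "worker_generate_stream"),
                             json=payload, stream=True, timeout=20)
    for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
        if not chunk:
            continue
        data = json.loads(chunk.decode())
        if data["error_code"] != 0:
            raise RuntimeError(f"worker error_code: {data['error_code']}")
        yield data["text"][skip_echo_len:].strip()  # drop the echoed prompt
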
View File

@@ -49,4 +49,5 @@ umap-learn
 notebook
 gradio==3.24.1
 gradio-client==0.0.8
-wandb
+wandb
+fschat==0.1.10