fix problem

csunny
2023-04-30 23:58:32 +08:00
parent c34e722412
commit 7861fc28ce
5 changed files with 60 additions and 108 deletions


@@ -14,8 +14,8 @@ from urllib.parse import urljoin
 from pilot.configs.model_config import LOGDIR, vicuna_model_server, LLM_MODEL
 from pilot.conversation import (
-    get_default_conv_template,
-    compute_skip_echo_len,
+    default_conversation,
+    conv_templates,
     SeparatorStyle
 )
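
For orientation: this import swap drops the FastChat-style helpers in favor of the repo's own template registry. A minimal sketch of what pilot.conversation presumably exposes (the names match the imports above; the field layout is an assumption modeled on FastChat's Conversation class, not the actual source):

    import dataclasses
    from enum import Enum, auto
    from typing import List, Optional, Tuple

    class SeparatorStyle(Enum):
        SINGLE = auto()
        TWO = auto()

    @dataclasses.dataclass
    class Conversation:
        system: str
        roles: Tuple[str, str]
        messages: List[List[Optional[str]]]
        offset: int
        sep_style: SeparatorStyle = SeparatorStyle.SINGLE
        sep: str = "###"
        sep2: Optional[str] = None

        def copy(self) -> "Conversation":
            # Copy the message list so shared templates stay pristine
            return dataclasses.replace(
                self, messages=[[role, msg] for role, msg in self.messages])

    conv_one_shot = Conversation(
        system="A chat between a curious human and an AI assistant.",
        roles=("Human", "Assistant"),
        messages=[],
        offset=0,
    )

    conv_templates = {"conv_one_shot": conv_one_shot}
    default_conversation = conv_one_shot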
@@ -43,29 +43,6 @@ priority = {
     "vicuna-13b": "aaa"
 }
 
-def set_global_vars(enable_moderation_):
-    global enable_moderation, models
-    enable_moderation = enable_moderation_
-
-def load_demo_single(url_params):
-    dropdown_update = gr.Dropdown.update(visible=True)
-    if "model" in url_params:
-        model = url_params["model"]
-        if model in models:
-            dropdown_update = gr.Dropdown.update(value=model, visible=True)
-
-    state = None
-    return (
-        state,
-        dropdown_update,
-        gr.Chatbot.update(visible=True),
-        gr.Textbox.update(visible=True),
-        gr.Button.update(visible=True),
-        gr.Row.update(visible=True),
-        gr.Accordion.update(visible=True),
-    )
 
 get_window_url_params = """
 function() {
     const params = new URLSearchParams(window.location.search);
@@ -78,10 +55,24 @@ function() {
     return url_params;
 }
 """
 
 def load_demo(url_params, request: gr.Request):
     logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
-    return load_demo_single(url_params)
+    dropdown_update = gr.Dropdown.update(visible=True)
+    if "model" in url_params:
+        model = url_params["model"]
+        if model in models:
+            dropdown_update = gr.Dropdown.update(
+                value=model, visible=True)
+
+    state = default_conversation.copy()
+    return (state,
+            dropdown_update,
+            gr.Chatbot.update(visible=True),
+            gr.Textbox.update(visible=True),
+            gr.Button.update(visible=True),
+            gr.Row.update(visible=True),
+            gr.Accordion.update(visible=True))
 
 
 def get_conv_log_filename():
     t = datetime.datetime.now()
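
The rewritten load_demo inlines what load_demo_single used to do, but seeds the session with default_conversation.copy() instead of None. The preselection rule, distilled into plain Python (the models list here is hypothetical; the app populates it elsewhere):

    models = ["vicuna-13b"]  # hypothetical registry of served models

    def pick_dropdown_update(url_params: dict) -> dict:
        """Mirror the branch above: always show the dropdown, and
        preselect ?model=<name> only if that model is actually served."""
        update = {"visible": True}
        model = url_params.get("model")
        if model in models:
            update["value"] = model
        return update

    assert pick_dropdown_update({"model": "vicuna-13b"}) == {"visible": True, "value": "vicuna-13b"}
    assert pick_dropdown_update({"model": "unknown"}) == {"visible": True}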
@@ -100,29 +91,23 @@ def clear_history(request: gr.Request):
     state = None
     return (state, [], "") + (disable_btn,) * 5
 
-def add_text(state, text, request: gr.Request):
-    logger.info(f"add_text. ip: {request.client.host}. len:{len(text)}")
-    if state is None:
-        state = get_default_conv_template("vicuna").copy()
-
+def add_text(state, text, request: gr.Request):
+    logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
     if len(text) <= 0:
         state.skip_next = True
         return (state, state.to_gradio_chatbot(), "") + (no_change_btn,) * 5
 
     # TODO: evaluate the model with ChatGPT
-    if enable_moderation:
+    if args.moderate:
         flagged = violates_moderation(text)
         if flagged:
             logger.info(f"violate moderation. ip: {request.client.host}. text: {text}")
             state.skip_next = True
-            return (state, state.to_gradio_chatbot(), moderation_msg) + (no_change_btn,) * 5
+            return (state, state.to_gradio_chatbot(), moderation_msg) + (
+                no_change_btn,) * 5
 
-    text = text[:1536]  # ?
+    text = text[:1536]  # Hard cut-off
     state.append_message(state.roles[0], text)
     state.append_message(state.roles[1], None)
     state.skip_next = False
     return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
 
 
 def post_process_code(code):
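
Before queuing a generation, add_text applies two guards: empty input flips skip_next so http_bot yields once and returns, and long input is hard-truncated at 1536 characters. The truncation guard as a standalone sketch (the constant mirrors the text[:1536] literal above):

    MAX_INPUT_CHARS = 1536  # mirrors text[:1536] in the diff

    def guard_text(text: str):
        """Return the (possibly truncated) user turn, or None for empty input."""
        if len(text) <= 0:
            return None  # caller sets state.skip_next = True and short-circuits
        return text[:MAX_INPUT_CHARS]  # hard cut-off

    assert guard_text("") is None
    assert len(guard_text("x" * 10_000)) == MAX_INPUT_CHARS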
@@ -136,104 +121,80 @@ def post_process_code(code):
     return code
 
 
 def http_bot(state, temperature, max_new_tokens, request: gr.Request):
     logger.info(f"http_bot. ip: {request.client.host}")
     start_tstamp = time.time()
-    model_name = LLM_MODEL
-    temperature = float(temperature)
-    max_new_tokens = int(max_new_tokens)
+    model_name = LLM_MODEL
 
     if state.skip_next:
         # This generate call is skipped due to invalid inputs
         yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
         return
     if len(state.messages) == state.offset + 2:
-        # First round of conversation
-        new_state = get_default_conv_template(model_name).copy()
+        # First round of conversation
+        template_name = "conv_one_shot"
+        new_state = conv_templates[template_name].copy()
         new_state.conv_id = uuid.uuid4().hex
         new_state.append_message(new_state.roles[0], state.messages[-2][1])
         new_state.append_message(new_state.roles[1], None)
         state = new_state
     prompt = state.get_prompt()
-    skip_echo_len = compute_skip_echo_len(prompt)
+    skip_echo_len = len(prompt.replace("</s>", " ")) + 1
 
     logger.info(f"State: {state}")
 
     # Make requests
     payload = {
         "model": model_name,
         "prompt": prompt,
-        "temperature": temperature,
-        "max_new_tokens": max_new_tokens,
-        "stop": state.sep if state.sep_style == SeparatorStyle.SINGLE else None,
+        "temperature": float(temperature),
+        "max_new_tokens": int(max_new_tokens),
+        "stop": state.sep if state.sep_style == SeparatorStyle.SINGLE else state.sep2,
     }
-    logger.info(f"Requert: \n{payload}")
+    logger.info(f"Request: \n {payload}")
 
     state.messages[-1][-1] = "▌"
     yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
     try:
-        response = requests.post(
-            url=urljoin(vicuna_model_server, "worker_generate_stream"),
-            headers=headers,
-            json=payload,
-            stream=True,
-            timeout=20,
-        )
+        # Stream output
+        response = requests.post(urljoin(vicuna_model_server, "worker_generate_stream"),
+                                 headers=headers, json=payload, stream=True, timeout=20)
         for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
             if chunk:
                 data = json.loads(chunk.decode())
                 logger.info(f"Response: {data}")
                 if data["error_code"] == 0:
-                    output = data["text"][skip_echo_len].strip()
+                    output = data["text"][skip_echo_len:].strip()
                     output = post_process_code(output)
                     state.messages[-1][-1] = output + "▌"
                     yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
                 else:
-                    output = data["text"] + f" (error_code): {data['error_code']}"
+                    output = data["text"] + f" (error_code: {data['error_code']})"
                     state.messages[-1][-1] = output
-                    yield (state, state.to_gradio_chatbot()) + (
-                        disable_btn,
-                        disable_btn,
-                        disable_btn,
-                        enable_btn,
-                        enable_btn
-                    )
+                    yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
                     return
                 time.sleep(0.02)
     except requests.exceptions.RequestException as e:
         state.messages[-1][-1] = server_error_msg + f" (error_code: 4)"
-        yield (state, state.to_gradio_chatbot()) + (
-            disable_btn,
-            disable_btn,
-            disable_btn,
-            enable_btn,
-            enable_btn
-        )
+        yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
         return
 
     state.messages[-1][-1] = state.messages[-1][-1][:-1]
     yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
 
     finish_tstamp = time.time()
     logger.info(f"{output}")
with open(get_conv_log_filename(), "a") as flog:
with open(get_conv_log_filename(), "a") as fout:
data = {
"tstamp": round(finish_tstamp, 4),
"type": "chat",
"model": model_name,
"gen_params": {
"temperature": temperature,
"max_new_tokens": max_new_tokens,
},
"start": round(start_tstamp, 4),
"finish": round(finish_tstamp, 4),
"finish": round(start_tstamp, 4),
"state": state.dict(),
"ip": request.client.host,
}
flog.write(json.dumps(data), + "\n")
fout.write(json.dumps(data) + "\n")
block_css = (
code_highlight_css
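
The one-character change from data["text"][skip_echo_len] to data["text"][skip_echo_len:] is the heart of this commit: indexing yields a single character, while slicing yields everything the model generated after the echoed prompt. A quick demonstration:

    prompt = "USER: hi ASSISTANT:"
    full_text = prompt + " Hello there!"  # the worker echoes the prompt, then generates
    skip_echo_len = len(prompt.replace("</s>", " ")) + 1  # 20, as computed above

    assert full_text[skip_echo_len] == "H"              # old code: one character
    assert full_text[skip_echo_len:] == "Hello there!"  # fixed code: the full reply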
@@ -382,8 +343,6 @@ if __name__ == "__main__":
     args = parser.parse_args()
-    logger.info(f"args: {args}")
-    set_global_vars(args.moderate)
+    logger.info(args)
 
     demo = build_webdemo()
     demo.queue(
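
For reference, the streaming loop in http_bot assumes the worker emits one JSON payload per chunk, NUL-delimited (hence iter_lines(..., delimiter=b"\0")). A minimal sketch of parsing that wire format, under that assumption:

    import json

    def parse_worker_stream(raw: bytes):
        """Yield decoded payloads from a NUL-delimited byte stream."""
        for chunk in raw.split(b"\0"):
            if chunk:
                yield json.loads(chunk.decode())

    raw = b'{"text": "Hel", "error_code": 0}\0{"text": "Hello!", "error_code": 0}\0'
    for data in parse_worker_stream(raw):
        print(data["error_code"], data["text"])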