version align

This commit is contained in:
csunny 2023-04-30 23:16:10 +08:00
parent 86909045fc
commit 3f13fe3a7f

View File

@ -109,13 +109,15 @@ def add_text(state, text, request: gr.Request):
if len(text) <= 0:
state.skip_next = True
return (state, state.to_gradio_chatbot(), "") + (no_change_btn,) * 5
# TODO: use ChatGPT to perform the model evaluation
if enable_moderation:
flagged = violates_moderation(text)
if flagged:
logger.info(f"violate moderation. ip: {request.client.host}. text: {text}")
state.skip_next = True
return (state, state.to_gradio_chatbot(), moderation_msg) + (no_change_btn,) * 5
text = text[:1536] # ?
state.append_message(state.roles[0], text)
state.append_message(state.roles[1], None)
@ -146,6 +148,7 @@ def http_bot(state, temperature, max_new_tokens, request: gr.Request):
return
if len(state.messages) == state.offset + 2:
# First round of the conversation
new_state = get_default_conv_template(model_name).copy()
new_state.conv_id = uuid.uuid4().hex
new_state.append_message(new_state.roles[0], state.messages[-2][1])
@ -162,7 +165,7 @@ def http_bot(state, temperature, max_new_tokens, request: gr.Request):
"prompt": prompt,
"temperature": temperature,
"max_new_tokens": max_new_tokens,
"stop": state.sep,
"stop": state.sep if state.sep_style == SeparatorStyle.SINGLE else None,
}
logger.info(f"Request: \n {payload}")
@ -171,11 +174,11 @@ def http_bot(state, temperature, max_new_tokens, request: gr.Request):
try:
response = requests.post(
url=urljoin(vicuna_model_server, "generate_stream"),
url=urljoin(vicuna_model_server, "worker_generate_stream"),
headers=headers,
json=payload,
stream=True,
timeout=60,
timeout=20,
)
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk: