version align

This commit is contained in:
csunny 2023-04-30 23:16:10 +08:00
parent 86909045fc
commit 3f13fe3a7f

View File

@@ -109,13 +109,15 @@ def add_text(state, text, request: gr.Request):
if len(text) <= 0: if len(text) <= 0:
state.skip_next = True state.skip_next = True
return (state, state.to_gradio_chatbot(), "") + (no_change_btn,) * 5 return (state, state.to_gradio_chatbot(), "") + (no_change_btn,) * 5
# TODO 根据ChatGPT来进行模型评估
if enable_moderation: if enable_moderation:
flagged = violates_moderation(text) flagged = violates_moderation(text)
if flagged: if flagged:
logger.info(f"violate moderation. ip: {request.client.host}. text: {text}") logger.info(f"violate moderation. ip: {request.client.host}. text: {text}")
state.skip_next = True state.skip_next = True
return (state, state.to_gradio_chatbot(), moderation_msg) + (no_change_btn,) * 5 return (state, state.to_gradio_chatbot(), moderation_msg) + (no_change_btn,) * 5
text = text[:1536] # ? text = text[:1536] # ?
state.append_message(state.roles[0], text) state.append_message(state.roles[0], text)
state.append_message(state.roles[1], None) state.append_message(state.roles[1], None)
@@ -146,6 +148,7 @@ def http_bot(state, temperature, max_new_tokens, request: gr.Request):
return return
if len(state.messages) == state.offset + 2: if len(state.messages) == state.offset + 2:
# 第一轮对话
new_state = get_default_conv_template(model_name).copy() new_state = get_default_conv_template(model_name).copy()
new_state.conv_id = uuid.uuid4().hex new_state.conv_id = uuid.uuid4().hex
new_state.append_message(new_state.roles[0], state.messages[-2][1]) new_state.append_message(new_state.roles[0], state.messages[-2][1])
@@ -162,7 +165,7 @@ def http_bot(state, temperature, max_new_tokens, request: gr.Request):
"prompt": prompt, "prompt": prompt,
"temperature": temperature, "temperature": temperature,
"max_new_tokens": max_new_tokens, "max_new_tokens": max_new_tokens,
"stop": state.sep, "stop": state.sep if state.sep_style == SeparatorStyle.SINGLE else None,
} }
logger.info(f"Request: \n {payload}") logger.info(f"Request: \n {payload}")
@@ -171,11 +174,11 @@ def http_bot(state, temperature, max_new_tokens, request: gr.Request):
try: try:
response = requests.post( response = requests.post(
url=urljoin(vicuna_model_server, "generate_stream"), url=urljoin(vicuna_model_server, "worker_generate_stream"),
headers=headers, headers=headers,
json=payload, json=payload,
stream=True, stream=True,
timeout=60, timeout=20,
) )
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk: if chunk: